In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Three-stage convolutional encoder that also returns its skip features.

    Every stage is a 3x3 convolution (padding 1) followed by ReLU. Stages two
    and three are preceded by a 2x2 max-pool with stride 2, so the latent map
    is produced at the resolution left after two poolings.

    Args:
        in_channels: number of channels of the input tensor.
        nf: base filter count; the stages emit ``nf``, ``2*nf`` and ``4*nf``
            channels respectively.
        latent_dim: number of channels of the latent map.
    """

    def __init__(self, in_channels, nf, latent_dim):
        super().__init__()
        # Shared settings: 3x3 kernels with padding 1 keep height/width unchanged.
        same_conv = dict(kernel_size=3, padding=1)
        stage_widths = (nf, nf * 2, nf * 4)

        self.conv1 = nn.Conv2d(in_channels, stage_widths[0], **same_conv)
        self.conv2 = nn.Conv2d(stage_widths[0], stage_widths[1], **same_conv)
        self.conv3 = nn.Conv2d(stage_widths[1], stage_widths[2], **same_conv)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Named "fc" but implemented as a convolution, so the latent keeps a
        # spatial layout.
        self.fc = nn.Conv2d(stage_widths[2], latent_dim, **same_conv)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A pair ``(z, [x1, x2, x3])`` where ``z`` is the latent map and the
            list holds the post-ReLU activations of the three stages, from
            the un-pooled stage to the twice-pooled one.
        """
        full_res = F.relu(self.conv1(x))
        half_res = F.relu(self.conv2(self.pool(full_res)))
        quarter_res = F.relu(self.conv3(self.pool(half_res)))
        latent = self.fc(quarter_res)
        return latent, [full_res, half_res, quarter_res]

class Decoder(nn.Module):
    """Reconstruction head that expands a latent map back to a larger grid.

    The latent is lifted to ``4*nf`` channels, then passed through two
    "upsample x2 -> 3x3 transposed conv -> ReLU" stages and a final 3x3
    transposed convolution without activation. All (transposed) convolutions
    use kernel 3 / padding 1, so only the upsampling changes height/width.

    Args:
        latent_dim: number of channels of the incoming latent map.
        nf: base filter count (the stages emit ``2*nf`` then ``nf`` channels).
        out_channels: number of channels of the reconstruction.
    """

    def __init__(self, latent_dim, nf, out_channels):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)

        self.fc = nn.Conv2d(latent_dim, nf * 4, **same_size)
        self.deconv1 = nn.ConvTranspose2d(nf * 4, nf * 2, **same_size)
        self.deconv2 = nn.ConvTranspose2d(nf * 2, nf, **same_size)
        self.deconv3 = nn.ConvTranspose2d(nf, out_channels, **same_size)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, z):
        """Decode latent ``z`` and return the reconstruction ``x_hat``."""
        hidden = self.fc(z)
        # Two identical expansion stages; each doubles height and width.
        for expand in (self.deconv1, self.deconv2):
            hidden = F.relu(expand(self.upsample(hidden)))
        # Linear output layer (no ReLU) at the final resolution.
        return self.deconv3(hidden)

class Regressor(nn.Module):
    """Prediction head that decodes the latent map with encoder skip features.

    After each x2 upsampling the running feature map is concatenated along the
    channel axis with an encoder activation (first the ``2*nf``-channel one,
    then the ``nf``-channel one), which is why the transposed convolutions
    take ``4*nf + 2*nf`` and ``2*nf + nf`` input channels.

    Args:
        latent_dim: number of channels of the incoming latent map.
        nf: base filter count shared with the encoder.
        out_channels: number of predicted channels.
    """

    def __init__(self, latent_dim, nf, out_channels):
        super().__init__()
        same_size = dict(kernel_size=3, padding=1)

        self.fc = nn.Conv2d(latent_dim, nf * 4, **same_size)
        self.conv1 = nn.ConvTranspose2d(nf * 4 + nf * 2, nf * 2, **same_size)
        self.conv2 = nn.ConvTranspose2d(nf * 2 + nf, nf, **same_size)
        self.conv3 = nn.ConvTranspose2d(nf, out_channels, **same_size)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, z, intermediate_outputs):
        """Predict ``y_hat`` from latent ``z`` and the encoder activations.

        Args:
            z: latent map.
            intermediate_outputs: exactly three encoder activations
                ``[x1, x2, x3]``; only ``x1`` and ``x2`` are consumed here.

        Returns:
            The prediction tensor (no output activation is applied).
        """
        # The deepest activation is unpacked (enforcing three entries) but unused.
        skip_full, skip_half, _ = intermediate_outputs

        features = self.upsample(self.fc(z))
        features = torch.cat([features, skip_half], dim=1)
        features = F.relu(self.conv1(features))

        features = self.upsample(features)
        features = torch.cat([features, skip_full], dim=1)
        features = F.relu(self.conv2(features))

        return self.conv3(features)

class HybridModel(nn.Module):
    """Shared encoder with a reconstruction head and a skip-connected prediction head.

    The input is first bilinearly resampled to ``internal_size`` so that the
    encoder's two 2x2 poolings and the heads' two x2 upsamplings line up, and
    both head outputs are bilinearly resampled to ``output_size``.

    Args:
        in_channels: channels of the input tensor.
        nf: base filter count shared by encoder and heads.
        latent_dim: channels of the latent map.
        rec_channels: channels produced by the reconstruction head.
        out_channels: channels produced by the prediction head.
        internal_size: ``(height, width)`` the input is resampled to before
            encoding. Both values must be positive multiples of 4, otherwise
            the upsampled features and the encoder skip features would differ
            in size. Defaults to the previously hard-coded ``(152, 48)``.
        output_size: ``(height, width)`` of both returned tensors. Defaults to
            the previously hard-coded ``(150, 49)``.

    Raises:
        ValueError: if either size is not a pair of positive integers or
            ``internal_size`` is not divisible by 4.
    """

    # Class-level fallbacks keep instances that were pickled before these
    # attributes existed (e.g. via ``torch.save(model, ...)``) working.
    internal_size = (152, 48)
    output_size = (150, 49)

    def __init__(self, in_channels, nf, latent_dim, rec_channels, out_channels,
                 internal_size=(152, 48), output_size=(150, 49)):
        super().__init__()

        internal_size = tuple(int(s) for s in internal_size)
        output_size = tuple(int(s) for s in output_size)
        if len(internal_size) != 2 or any(s <= 0 or s % 4 != 0 for s in internal_size):
            raise ValueError(
                f"internal_size must be two positive multiples of 4, got {internal_size}"
            )
        if len(output_size) != 2 or any(s <= 0 for s in output_size):
            raise ValueError(
                f"output_size must be two positive integers, got {output_size}"
            )
        self.internal_size = internal_size
        self.output_size = output_size

        self.encoder = Encoder(in_channels, nf, latent_dim)
        self.decoder = Decoder(latent_dim, nf, rec_channels)
        self.regressor = Regressor(latent_dim, nf, out_channels)

    def forward(self, x):
        """Run both heads on ``x``.

        Args:
            x: input batch; bilinear resampling with a 2-D size requires a
                4-D ``(N, C, H, W)`` tensor.

        Returns:
            ``(x_hat, y_hat)``: the reconstruction and the prediction, both
            resampled to ``output_size``.
        """
        x = F.interpolate(x, size=self.internal_size, mode='bilinear', align_corners=False)
        z, intermediate_outputs = self.encoder(x)

        x_hat = self.decoder(z)
        x_hat = F.interpolate(x_hat, size=self.output_size, mode='bilinear', align_corners=False)

        y_hat = self.regressor(z, intermediate_outputs)
        y_hat = F.interpolate(y_hat, size=self.output_size, mode='bilinear', align_corners=False)
        return x_hat, y_hat

# Smoke test: build the sub-modules on random data and export an untrained
# HybridModel to ONNX.
in_channels, rec_channels, out_channels = 29, 13, 9
latent_dim = 128
nx, ny = 128, 48
nf = 16

model_encoder = Encoder(in_channels, nf, latent_dim)
model_decoder = Decoder(latent_dim, nf, out_channels)

# Random stand-in batch: one sample with `in_channels` channels on a 128 x 64 grid.
x = torch.randn(1, in_channels, 128, 64)
z, intermediate_outputs = model_encoder(x)

hybrid_model = HybridModel(in_channels, nf, latent_dim, rec_channels, out_channels)
# verbose=True prints the exported graph below the cell.
torch.onnx.export(
    hybrid_model,
    x,
    "hybrid_model.onnx",
    verbose=True,
    input_names=['input'],
    output_names=['output1', 'output2'],
)

Exported graph: graph(%input : Float(1, 29, 128, 64, strides=[237568, 8192, 64, 1], requires_grad=0, device=cpu),
      %encoder.conv1.weight : Float(16, 29, 3, 3, strides=[261, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv1.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %encoder.conv2.weight : Float(32, 16, 3, 3, strides=[144, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv2.bias : Float(32, strides=[1], requires_grad=1, device=cpu),
      %encoder.conv3.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv3.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %encoder.fc.weight : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.fc.bias : Float(128, strides=[1], requires_grad=1, device=cpu),
      %decoder.fc.weight : Float(64, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cpu),
      %decoder.fc.bias : Float(64, strides=[1],

In [2]:
import netCDF4 as nc
import numpy as np
# Read the full 'X' and 'y' variables from the NetCDF training file; the file
# handle is closed as soon as both arrays have been read.
# NOTE(review): netCDF4 may hand back masked arrays — confirm missing values
# are handled before these are converted to tensors.
with nc.Dataset("training_data_150_ocean.nc") as nc_file:
    X = nc_file['X'][:]
    y = nc_file['y'][:]

In [5]:
import torch
# --- Training setup ---------------------------------------------------------
# Tunable constants gathered in one place instead of inline literals.
SEED = 0
TRAIN_FRACTION = 0.8   # leading share of the samples used for training
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
REC_CHANNELS = 13      # channels produced by the reconstruction head

# Seed the global generator so weight initialisation and DataLoader shuffling
# are repeatable across kernel restarts.
torch.manual_seed(SEED)

n_all = X.shape[0]
nt = int(n_all * TRAIN_FRACTION)  # samples [0, nt) train, [nt, n_all) held out
print(X.shape)
print(y.shape)

# The arrays are stored channels-last; permute to the channels-first layout
# that Conv2d expects.
X_torch = torch.tensor(X, dtype=torch.float32).permute(0, 3, 1, 2)
y_torch = torch.tensor(y, dtype=torch.float32).permute(0, 3, 1, 2)

# Channel counts are read from the data so the model cannot drift out of sync
# with the file contents.
model = HybridModel(
    in_channels=X_torch.shape[1],
    nf=16,
    latent_dim=16,
    rec_channels=REC_CHANNELS,
    out_channels=y_torch.shape[1],
)
dataset = torch.utils.data.TensorDataset(X_torch[:nt], y_torch[:nt])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

(4501, 150, 49, 19)
(4501, 150, 49, 12)


In [6]:
# Joint training: reconstruction loss on the leading input channels plus
# prediction loss on the targets, summed with equal weight.
n_epochs = 20
for epoch in range(n_epochs):
    running_loss = 0.0  # sum of per-batch losses; averaged when reported
    for inputs, target in dataloader:
        optimizer.zero_grad()
        outputs, y_ = model(inputs)

        # Compare the reconstruction against as many leading input channels as
        # the reconstruction head emits, instead of repeating that count as a
        # literal here.
        n_rec = outputs.shape[1]
        loss1 = criterion(outputs, inputs[:, :n_rec, :, :])
        loss2 = criterion(y_, target)
        loss = loss1 + loss2

        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}/{n_epochs}, Loss: {running_loss/len(dataloader)}")

Epoch 1/20, Loss: 1.019480422007299
Epoch 2/20, Loss: 0.8051752252916319
Epoch 3/20, Loss: 0.5713982268244819
Epoch 4/20, Loss: 0.5322593151204353
Epoch 5/20, Loss: 0.5000425881780355
Epoch 6/20, Loss: 0.5151425774118542
Epoch 7/20, Loss: 0.5155692382723884
Epoch 8/20, Loss: 0.4795691686130203
Epoch 9/20, Loss: 0.46254190591584265
Epoch 10/20, Loss: 0.4445244071230424
Epoch 11/20, Loss: 0.4357502328611053
Epoch 12/20, Loss: 0.4293900540972178
Epoch 13/20, Loss: 0.4210650435591166
Epoch 14/20, Loss: 0.41336016000899595
Epoch 15/20, Loss: 0.4664328032362778
Epoch 16/20, Loss: 0.42880507908036225
Epoch 17/20, Loss: 0.4123967717179155
Epoch 18/20, Loss: 0.4073435810551179
Epoch 19/20, Loss: 0.4015618010168582
Epoch 20/20, Loss: 0.40106637881392926


In [8]:
# Persist the trained network three ways: weights only, the whole pickled
# module, and an ONNX graph traced with the first sample as example input.
torch.save(model.state_dict(), "ocean_model_150_20e.pth")
torch.save(model, "ocean_model_150_20e_full.pth")
torch.onnx.export(
    model,
    X_torch[0].unsqueeze(0),  # add the batch axis back: (1, C, H, W)
    "ocean_150_model.onnx",
    verbose=True,
    input_names=['input'],
    output_names=['output_rec', 'output_pred'],
)

Exported graph: graph(%input : Float(1, 19, 150, 49, strides=[19, 1, 931, 19], requires_grad=0, device=cpu),
      %encoder.conv1.weight : Float(16, 19, 3, 3, strides=[171, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv1.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %encoder.conv2.weight : Float(32, 16, 3, 3, strides=[144, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv2.bias : Float(32, strides=[1], requires_grad=1, device=cpu),
      %encoder.conv3.weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.conv3.bias : Float(64, strides=[1], requires_grad=1, device=cpu),
      %encoder.fc.weight : Float(16, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=1, device=cpu),
      %encoder.fc.bias : Float(16, strides=[1], requires_grad=1, device=cpu),
      %decoder.fc.weight : Float(64, 16, 3, 3, strides=[144, 9, 3, 1], requires_grad=1, device=cpu),
      %decoder.fc.bias : Float(64, strides=[1], requires

In [9]:
# Evaluate on the held-out tail of the data. Gradients are not needed here, so
# the forward pass runs under no_grad to avoid building an autograd graph for
# the whole evaluation batch.
with torch.no_grad():
    _, y_pred = model(X_torch[nt:])  # the reconstruction output is not used here

y_pred_np = y_pred.numpy()
y_target_np = y_torch.numpy()
print(y.shape)

# Per-channel Pearson correlation between prediction and target, pooled over
# all held-out samples and grid points. The channel count comes from the
# prediction itself rather than a literal.
for i in range(y_pred_np.shape[1]):
    print(np.corrcoef(y_pred_np[:, i, :, :].ravel(), y_target_np[nt:, i, :, :].ravel()))


(4501, 150, 49, 12)
[[1.         0.98485537]
 [0.98485537 1.        ]]
[[1.         0.80281495]
 [0.80281495 1.        ]]
[[1.         0.95289309]
 [0.95289309 1.        ]]
[[1.         0.88243332]
 [0.88243332 1.        ]]
[[1.         0.84299261]
 [0.84299261 1.        ]]
[[1.         0.69640302]
 [0.69640302 1.        ]]
[[1.         0.67005201]
 [0.67005201 1.        ]]
[[1.         0.52503581]
 [0.52503581 1.        ]]
[[1.         0.55565989]
 [0.55565989 1.        ]]
[[1.         0.79750109]
 [0.79750109 1.        ]]
[[1.         0.49641696]
 [0.49641696 1.        ]]
[[1.         0.64207041]
 [0.64207041 1.        ]]


In [None]:

# Marker cell: prints a fixed label and nothing else.
print("split")

In [14]:
# NOTE(review): the selection below is disabled; it refers to `sfc_type_L`,
# which is not defined in the visible cells.
#a_land=np.nonzero(np.array(sfc_type_L[nt:])[:,:,:48]>0.01)
i = 1


(5397, 128, 49, 12)
[[1.         0.98603103]
 [0.98603103 1.        ]]
[[1.         0.80206456]
 [0.80206456 1.        ]]
[[1.         0.95443157]
 [0.95443157 1.        ]]
[[1.         0.89222258]
 [0.89222258 1.        ]]
[[1.        0.8573486]
 [0.8573486 1.       ]]
[[1.         0.69373519]
 [0.69373519 1.        ]]
[[1.         0.65264072]
 [0.65264072 1.        ]]
[[1.        0.4794228]
 [0.4794228 1.       ]]
[[1.         0.54139542]
 [0.54139542 1.        ]]
[[1.         0.78273988]
 [0.78273988 1.        ]]
[[1.         0.48225828]
 [0.48225828 1.        ]]
[[1.         0.65136369]
 [0.65136369 1.        ]]
