In [1]:
import numpy as np
import einops
from dataset import ClimateHackDataset
from torch.utils.data import DataLoader
import torch
from loss import MS_SSIMLoss
from submission.model import Model
import torch.optim as optim
from torchinfo import summary

SEED = 0              # fixed seed so model weight initialisation is reproducible across runs
NUM_FRAMES = 720      # leading frames kept after flattening the first two axes
BATCH_SIZE = 4
LEARNING_RATE = 1e-4

# np.load on an .npz returns an NpzFile that holds the file descriptor open;
# the context manager closes it once the array has been read into memory.
with np.load("./data/data.npz") as data:
    dataset = data["data"]
# Merge the first two axes into a single frame axis: (d, t, y, x) -> (d*t, y, x).
dataset = einops.rearrange(dataset, 'd t y x -> (d t) y x')
# Copy the slice so it no longer views (and therefore keeps alive) the full array.
dataset = dataset[:NUM_FRAMES].copy()  # original note: (720, 891, 1843)

# NOTE(review): the shape notes below are carried over from the original comments;
# confirm them against ClimateHackDataset before relying on them.
ch_dataset = ClimateHackDataset(dataset)  # (320 (NUM_SEQUENCES), 2 (FEATURES, TARGETS), 12/24, 128/64, 128/64)
ch_dataloader = DataLoader(ch_dataset, batch_size=BATCH_SIZE)  # (80 (NUM_BATCHES), 2 (FEATURES, TARGETS), 4 (BATCH_SIZE), 12/24, 128/64, 128/64)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)  # seed before constructing the model so its initial weights are deterministic
model = Model().to(device)
optimiser = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = MS_SSIMLoss(channels=1)

In [2]:
# Dummy input shape for the layer-by-layer summary.
# Presumably (batch, frames, height, width) — confirm against ClimateHackDataset.
SUMMARY_INPUT_SIZE = (4, 11, 128, 128)
summary(model, input_size=SUMMARY_INPUT_SIZE)

Layer (type:depth-idx)                   Output Shape              Param #
Model                                    --                        --
├─Conv3d: 1-1                            [4, 4, 9, 63, 63]         504
├─MaxPool3d: 1-2                         [4, 4, 9, 31, 31]         --
├─Conv3d: 1-3                            [4, 16, 9, 16, 16]        1,744
├─MaxPool3d: 1-4                         [4, 16, 9, 8, 8]          --
├─Conv3d: 1-5                            [4, 64, 7, 3, 3]          27,712
├─MaxPool3d: 1-6                         [4, 64, 7, 1, 1]          --
├─Flatten: 1-7                           [4, 448]                  --
├─Linear: 1-8                            [4, 16384]                7,356,416
Total params: 7,386,376
Trainable params: 7,386,376
Non-trainable params: 0
Total mult-adds (M): 124.50
Input size (MB): 2.88
Forward/backward pass size (MB): 6.41
Params size (MB): 29.55
Estimated Total Size (MB): 38.83

In [None]:
# Sanity check: take one optimiser step per "epoch" on the first batch the
# dataloader yields. If the loss cannot be driven down here, full training is
# unlikely to work either.
losses = []
EPOCHS = 500
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}")
    # NOTE(review): a fresh iterator is created every step. With an unshuffled
    # map-style dataset this is always the same first batch; if ClimateHackDataset
    # samples randomly, each step sees a different batch — confirm.
    # Original note: features ~[4, 11, 128, 128] — verify against the dataset.
    batch_features, batch_targets = next(iter(ch_dataloader))
    optimiser.zero_grad()
    batch_predictions = model(batch_features.to(device))
    # Insert a singleton axis at dim 2, exactly as the full training loop does
    # when feeding MS_SSIMLoss(channels=1); without it the tensor ranks differ.
    batch_loss = criterion(
        batch_predictions.unsqueeze(dim=2),
        batch_targets.to(device).unsqueeze(dim=2),
    )
    batch_loss.backward()
    optimiser.step()
    # Only one batch per step, so the batch-size-weighted mean is just this batch's loss.
    losses.append(batch_loss.item())
    print(f"Loss for epoch {epoch + 1}/{EPOCHS}: {losses[-1]}")

In [5]:
def train_one_epoch(model, dataloader, optimiser, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean per-sample loss.

    Each batch's loss is multiplied by its batch size before averaging, so a
    smaller final batch is weighted correctly (assuming ``criterion`` returns a
    per-batch mean). Returns NaN instead of raising ZeroDivisionError if the
    dataloader yields no batches.
    """
    model.train()  # be explicit rather than relying on whatever mode earlier cells left
    running_loss = 0.0
    count = 0
    for batch_features, batch_targets in dataloader:
        optimiser.zero_grad()
        batch_predictions = model(batch_features.to(device))
        # Singleton axis at dim 2 for MS_SSIMLoss(channels=1).
        batch_loss = criterion(
            batch_predictions.unsqueeze(dim=2),
            batch_targets.to(device).unsqueeze(dim=2),
        )
        batch_loss.backward()
        optimiser.step()
        batch_size = batch_predictions.shape[0]
        running_loss += batch_loss.item() * batch_size
        count += batch_size
    return running_loss / count if count else float("nan")


losses = []
EPOCHS = 5
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}")
    losses.append(train_one_epoch(model, ch_dataloader, optimiser, criterion, device))
    print(f"Loss for epoch {epoch + 1}/{EPOCHS}: {losses[-1]}")

Epoch 1
Loss for epoch 1/5: 0.09466870029767355
Epoch 2
Loss for epoch 2/5: 0.09407109121481577
Epoch 3
Loss for epoch 3/5: 0.0930822322765986
Epoch 4
Loss for epoch 4/5: 0.09207746535539627
Epoch 5
Loss for epoch 5/5: 0.09105079869429271


In [None]:
# Persist only the learned parameters (the state_dict), not the Model object itself.
MODEL_CHECKPOINT_PATH = 'submission/model.pt'
model_weights = model.state_dict()
torch.save(model_weights, MODEL_CHECKPOINT_PATH)