In [1]:
import random

import einops
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchinfo import summary

from dataset import ClimateHackDataset
from loss import MS_SSIMLoss
from submission.model import Model

# --- Configuration ------------------------------------------------------------
DATA_PATH = "./data/data.npz"
NUM_FRAMES = 720  # keep only the first 720 flattened frames
BATCH_SIZE = 4
SEED = 42

# Seed every RNG before building the model so weight initialisation (and, if
# ClimateHackDataset draws random numbers, its sampling) is reproducible on
# Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# np.load on an .npz returns an NpzFile that keeps the archive open until it is
# closed; read the array inside a `with` block so the file handle is released.
with np.load(DATA_PATH) as data:
    frames_raw = data["data"]

# Merge the two leading axes ('d' and 't' -- presumably day and time-of-day,
# confirm against the data source) into a single frame axis, then keep the
# first NUM_FRAMES frames.
dataset = einops.rearrange(frames_raw, 'd t y x -> (d t) y x')
dataset = dataset[:NUM_FRAMES] # (720, 891, 1843)

ch_dataset = ClimateHackDataset(dataset) # (320 (NUM_SEQUENCES), 2 (FEATURES, TARGETS), 12/24, 128/64, 128/64)
# shuffle is left at the DataLoader default (False), so batch order is deterministic.
ch_dataloader = DataLoader(ch_dataset, batch_size=BATCH_SIZE) # (80 (NUM_BATCHES), 2 (FEATURES, TARGETS), 4 (BATCH_SIZE), 12/24, 128/64, 128/64)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
optimiser = optim.Adam(model.parameters(), lr=1e-4)
criterion = MS_SSIMLoss(channels=1)

In [None]:
# Layer-by-layer overview of the model for a dummy batch of shape
# (4, 12, 128, 128): 4 matches the DataLoader batch size, and 12 x 128 x 128
# follows the feature shape noted on the dataset (confirm against ClimateHackDataset).
summary_input_size = (4, 12, 128, 128)
summary(model, input_size=summary_input_size)

In [None]:
# Train repeatedly on one fixed batch -- apparently an overfitting sanity check:
# if the model/loss wiring is right, the loss should keep falling.
losses = []
EPOCHS = 500

# Fetch the batch once and move it to the device once. Rebuilding the iterator
# every epoch was wasteful and, if ClimateHackDataset samples randomly, would
# not actually reuse the same batch.
batch_features, batch_targets = next(iter(ch_dataloader))
batch_features = batch_features.to(device)
batch_targets = batch_targets.to(device)

for epoch in range(EPOCHS):
    optimiser.zero_grad()
    batch_predictions = model(batch_features)
    # Insert a singleton channel axis at dim 2, exactly as the full training loop
    # does, since the criterion was constructed with channels=1.
    batch_loss = criterion(batch_predictions.unsqueeze(dim=2), batch_targets.unsqueeze(dim=2))
    batch_loss.backward()
    optimiser.step()
    # Only one batch per epoch, so the epoch loss is simply this batch's loss.
    losses.append(batch_loss.item())
    print(f"Loss for epoch {epoch + 1}/{EPOCHS}: {losses[-1]}")

In [None]:
# Full training: EPOCHS passes over ch_dataloader, recording the sample-weighted
# mean loss of each epoch in `losses`.
losses = []
EPOCHS = 5
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}")
    running_loss = 0.0
    count = 0
    for batch_features, batch_targets in ch_dataloader:
        optimiser.zero_grad()
        batch_predictions = model(batch_features.to(device))
        # Insert a singleton channel axis at dim 2 for the criterion, which was
        # constructed with channels=1.
        batch_loss = criterion(batch_predictions.unsqueeze(dim=2), batch_targets.unsqueeze(dim=2).to(device))
        batch_loss.backward()
        optimiser.step()
        # Weight by the number of samples so a short final batch doesn't skew the epoch mean.
        n_in_batch = batch_predictions.shape[0]
        running_loss += batch_loss.item() * n_in_batch
        count += n_in_batch
    if count == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise RuntimeError("ch_dataloader yielded no batches; cannot compute the epoch loss")
    losses.append(running_loss / count)
    print(f"Loss for epoch {epoch + 1}/{EPOCHS}: {losses[-1]}")

In [None]:
# Persist only the trained parameters (the state_dict), not the pickled module
# object. NOTE(review): tensors keep their current device (CUDA when available),
# so loading on a CPU-only machine may need torch.load(..., map_location="cpu").
trained_state = model.state_dict()
torch.save(trained_state, 'submission/model.pt')