In [1]:
# Display the kernel's current working directory. The next cell imports project-local
# modules (CellDataset, MoCoResNetBackbone, MoCoV2Loss), which are presumably resolved
# relative to this directory — verify it points at the project root before running.
%pwd

'/home/group.kurse/cviwo013/ComputerVisionProject'

In [2]:
import numpy as np
import os
from pathlib import Path
from PIL import Image
import json
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import transforms as T
import torch
import random
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from CellDataset import CellDataset
from MoCoResNetBackbone import MoCoResNetBackbone
from MoCoV2Loss import MoCoV2Loss

In [3]:
# MoCo training cell: builds model/data/loss/optimizer, resumes from the latest
# checkpoint in `modelPath / trainingName` (if any), then trains up to `epochs`,
# checkpointing every `checkpoint_every` epochs and after the final epoch.
modelPath = Path("/scratch/cv-course-group-5/models/")
trainingName = "training0/"
checkpoint_dir = modelPath / trainingName

device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

model = MoCoResNetBackbone()
model.to(device)

dataset = CellDataset()

moco_loss = MoCoV2Loss(device=device)

epochs = 50
batch_size = 64
learning_rate = 0.001
momentum = 0.9
checkpoint_every = 5  # save a checkpoint every N epochs (the last epoch is always saved too)

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum
)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,        # Adjust to CPU core count
    pin_memory=True,       # Enables fast transfer to GPU
)

# Per-epoch list of batch losses; epochs skipped by resuming stay empty.
losses = [[] for _ in range(epochs)]


def _latest_checkpoint_epoch(directory):
    """Return the largest N such that ``directory / f"epoch{N}.pth"`` exists, or None.

    Scans the directory instead of probing fixed multiples, so a final-epoch
    checkpoint (or a gap in the saved sequence) does not hide later checkpoints.
    Optimizer files (``optimizer_epoch{N}.pth``) do not match the pattern.
    """
    if not directory.is_dir():
        return None
    found = []
    for path in directory.glob("epoch*.pth"):
        suffix = path.stem[len("epoch"):]
        if suffix.isdecimal():
            found.append(int(suffix))
    return max(found, default=None)


def _optimizer_checkpoint_path(directory, epoch):
    """Path of the optimizer state saved alongside the model checkpoint for `epoch`."""
    return directory / f"optimizer_epoch{epoch}.pth"


# --- Resume -------------------------------------------------------------------
# Without a checkpoint we must start at epoch 0 (the old `checkpoint_epoch + 1`
# silently skipped epoch 0 on a fresh run).
checkpoint_epoch = _latest_checkpoint_epoch(checkpoint_dir)
start_epoch = 0
if checkpoint_epoch is not None:
    # map_location lets a checkpoint written on another GPU (or on GPU, read on CPU) load here.
    state_dict = torch.load(checkpoint_dir / f"epoch{checkpoint_epoch}.pth", map_location=device)
    model.load_state_dict(state_dict)
    # Restore SGD momentum buffers when available; checkpoints from older runs
    # only contain model weights, so this file is optional.
    optimizer_path = _optimizer_checkpoint_path(checkpoint_dir, checkpoint_epoch)
    if optimizer_path.exists():
        optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))
    start_epoch = checkpoint_epoch + 1
    # NOTE(review): if MoCoV2Loss keeps internal state (e.g. a queue of negatives),
    # it is not checkpointed here and restarts empty on resume — confirm.

# --- Train --------------------------------------------------------------------
for epoch in range(start_epoch, epochs):

    model.train()
    for keys, queries in tqdm(dataloader, desc=f"Epoch {epoch}", ncols=100):

        keys = keys.to(device, non_blocking=True)
        queries = queries.to(device, non_blocking=True)

        # The model takes (queries, keys) and returns their encodings in that order.
        query_encodings, key_encodings = model(queries, keys)

        loss = moco_loss(query_encodings, key_encodings)
        losses[epoch].append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Always save the final epoch so the finished model is never lost.
    if epoch % checkpoint_every == 0 or epoch == epochs - 1:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), checkpoint_dir / f"epoch{epoch}.pth")
        torch.save(optimizer.state_dict(), _optimizer_checkpoint_path(checkpoint_dir, epoch))

    if losses[epoch]:
        print(f"Epoch {epoch} loss: {sum(losses[epoch]) / len(losses[epoch])} ")
    else:
        print(f"Epoch {epoch}: dataloader yielded no batches")

101052


Epoch 16: 100%|█████████████████████████████████████████████████| 1579/1579 [04:45<00:00,  5.54it/s]


Epoch 16 loss: 0.02543616660312054 


Epoch 17: 100%|█████████████████████████████████████████████████| 1579/1579 [04:43<00:00,  5.57it/s]


Epoch 17 loss: 0.026215001586836453 


Epoch 18: 100%|█████████████████████████████████████████████████| 1579/1579 [04:43<00:00,  5.57it/s]


Epoch 18 loss: 0.0261266659992435 


Epoch 19: 100%|█████████████████████████████████████████████████| 1579/1579 [04:43<00:00,  5.57it/s]


Epoch 19 loss: 0.02602627544892106 


Epoch 20: 100%|█████████████████████████████████████████████████| 1579/1579 [04:43<00:00,  5.57it/s]


Epoch 20 loss: 0.025940908774960032 


Epoch 21: 100%|█████████████████████████████████████████████████| 1579/1579 [04:43<00:00,  5.56it/s]


Epoch 21 loss: 0.025875093421277705 


Epoch 22: 100%|█████████████████████████████████████████████████| 1579/1579 [04:43<00:00,  5.57it/s]


Epoch 22 loss: 0.025790167412118415 


Epoch 23: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.54it/s]


Epoch 23 loss: 0.025720925462574352 


Epoch 24: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 24 loss: 0.025662086309512406 


Epoch 25: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 25 loss: 0.025602313055833204 


Epoch 26: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 26 loss: 0.025549400909234632 


Epoch 27: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 27 loss: 0.025484554973289966 


Epoch 28: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 28 loss: 0.02543984104285254 


Epoch 29: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 29 loss: 0.025396995584965883 


Epoch 30: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 30 loss: 0.02533758018210653 


Epoch 31: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 31 loss: 0.02530042749273392 


Epoch 32: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 32 loss: 0.025266294238441272 


Epoch 33: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 33 loss: 0.025222987000269554 


Epoch 34: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 34 loss: 0.025185455916717205 


Epoch 35: 100%|█████████████████████████████████████████████████| 1579/1579 [04:44<00:00,  5.55it/s]


Epoch 35 loss: 0.02515213587395068 


Epoch 36:   1%|▌                                                  | 19/1579 [00:04<05:28,  4.75it/s]


KeyboardInterrupt: 