In [3]:
import os

import pandas as pd
import pytorch_lightning as pl
import skimage as ski
import torch
import wandb
import yaml
from pl_bolts.models.self_supervised import SwAV
from pl_bolts.transforms.self_supervised.swav_transforms import (
    SwAVTrainDataTransform,
    SwAVEvalDataTransform
)
from torch.utils.data import Dataset
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import v2 as transforms

import data
from data import MedakaDataset

In [4]:
# Authenticate with Weights & Biases before any WandbLogger / run is created.
# The cell displays wandb.login()'s return value.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mey267[0m ([33mey267-university-of-cambridge[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
from pytorch_lightning.loggers import WandbLogger

In [6]:
# Point wandb's artifact staging/storage and cache directories at the
# nobackup volume instead of the home-directory defaults, and set a timeout.
# NOTE(review): confirm WANDB_TIMEOUT is honoured by the installed wandb;
# the documented knobs may be WANDB_INIT_TIMEOUT / WANDB_HTTP_TIMEOUT.
wandb_env = {
    'WANDB_DATA_DIR': '/hps/nobackup/birney/users/esther/wandb/artifacts/staging/',
    'WANDB_ARTIFACT_DIR': '/hps/nobackup/birney/users/esther/wandb/artifacts/',
    'WANDB_CACHE_DIR': '/hps/nobackup/birney/users/esther/wandb/.cache/',
    'WANDB_TIMEOUT': '120',
}
os.environ.update(wandb_env)

In [7]:
def load_config(config_path):
    """Read a YAML file and return its parsed contents.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever the YAML document contains (presumably a dict of settings —
        confirm against the config file).
    """
    # NOTE(review): FullLoader can build some Python objects from python/* tags.
    # If the config never uses such tags, yaml.safe_load is the safer choice.
    with open(config_path, "r") as fh:
        return yaml.load(fh, Loader=yaml.FullLoader)

CONFIG_PATH = "/nfs/research/birney/users/esther/medaka-img/src_files/wandb_yaml/vae-v0.yaml"
# NOTE(review): this reuses the VAE config ("vae-v0.yaml") for the SwAV run;
# in this notebook it is only forwarded to MedakaDataset — confirm that is intended.
config = load_config(CONFIG_PATH)

In [8]:
# Seed Python/NumPy/torch RNGs so weight init, augmentation and shuffling are
# reproducible across Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

batch_size = 128
train_len = 800     # training-split size; MedakaDataModule reads this by default
max_epochs = 100    # horizon of SwAV's warmup + cosine LR schedule (library default)

# model
model = SwAV(
    gpus=1,
    num_samples=train_len,   # with batch_size, sizes the per-epoch step count of the LR schedule
    dataset='medakadataset_2024-10-03',
    batch_size=batch_size,
    num_crops=(2, 4),        # must match the number of crops the SwAV transforms produce
    learning_rate=0.0001,
    num_prototypes=50,
    max_epochs=max_epochs,   # explicit so the Trainer can be run for the same horizon
)

In [9]:
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [10]:
# SwAV is trained through a LightningDataModule that supplies both loaders.
class MedakaDataModule(pl.LightningDataModule):
    """Serve the Medaka image set as complementary train/val splits.

    Both loaders build a ``MedakaDataset`` over the same CSV and split it with
    ``random_split`` using a generator seeded with ``split_seed``. The first
    ``train_len`` shuffled samples go to training and the remainder to
    validation. Because both calls use the same seed and dataset length, the two
    subsets are disjoint and identical on every call. An unseeded split would
    draw a fresh permutation per call and let validation overlap training.

    Args:
        train_transforms: Transform forwarded to ``MedakaDataset`` for training
            samples (e.g. a ``SwAVTrainDataTransform``).
        val_transforms: Transform forwarded for validation samples.
        batch_size: Batch size for both loaders.
        train_len: Size of the training split. ``None`` (default) falls back to
            the notebook-level ``train_len`` variable for backward compatibility.
        split_seed: Seed for the train/val split.
        num_workers: DataLoader worker processes (0 = load in the main process).
        drop_last: Drop the final incomplete training batch so an epoch has
            exactly ``train_len // batch_size`` steps. This is presumably the
            count SwAV assumes when sizing its LR schedule from ``num_samples``;
            confirm against the installed pl_bolts.

    Raises:
        ValueError: (from the loaders) if ``train_len`` is negative or exceeds
            the dataset size.
    """

    DATA_CSV = '/nfs/research/birney/users/esther/medaka-img/src_files/train_set_2024-10-03.csv'
    DIRECTION_CSV = '/nfs/research/birney/users/esther/medaka-img/scripts/left-facing-fish.csv'
    SRC_DIR = '/nfs/research/birney/users/esther/medaka-img/src_files/train_2024-10-03/'

    def __init__(self, train_transforms, val_transforms, batch_size,
                 train_len=None, split_seed=42, num_workers=0, drop_last=True):
        super().__init__()
        # The parameter shadows the notebook global of the same name, so read
        # the global explicitly when no value is passed.
        self.train_len = globals()['train_len'] if train_len is None else train_len
        self.batch_size = batch_size
        self.split_seed = split_seed
        self.num_workers = num_workers
        self.drop_last = drop_last

        self.train_transforms = train_transforms
        self.val_transforms = val_transforms

    def _make_dataset(self, transform):
        """Build a MedakaDataset over the fixed CSV/image paths using ``transform``."""
        # `config` is the notebook-level object returned by load_config().
        return MedakaDataset(data_csv=self.DATA_CSV,
                             direction_csv=self.DIRECTION_CSV,
                             src_dir=self.SRC_DIR,
                             transform=transform,
                             config=config)

    def _split(self, dataset):
        """Return ``(train_subset, val_subset)`` from a seeded split of ``dataset``."""
        n_total = len(dataset)
        if not 0 <= self.train_len <= n_total:
            raise ValueError(
                f"train_len={self.train_len} must be between 0 and the dataset size ({n_total})"
            )
        # A fresh generator with the same seed reproduces the same permutation
        # in every call, keeping train and val disjoint.
        generator = torch.Generator().manual_seed(self.split_seed)
        return random_split(dataset, [self.train_len, n_total - self.train_len],
                            generator=generator)

    def train_dataloader(self):
        train_subset, _ = self._split(self._make_dataset(self.train_transforms))
        return DataLoader(train_subset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, drop_last=self.drop_last)

    def val_dataloader(self):
        _, val_subset = self._split(self._make_dataset(self.val_transforms))
        return DataLoader(val_subset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

In [11]:
# SwAV multi-crop augmentation pipelines with crop sizes (224, 96) and Gaussian
# blur disabled. The number of crops is left at the pl_bolts default.
# NOTE(review): confirm that default equals the model's num_crops=(2, 4).
swav_crop_kwargs = {"size_crops": (224, 96), "gaussian_blur": False}
train_transforms = SwAVTrainDataTransform(**swav_crop_kwargs)
val_transforms = SwAVEvalDataTransform(**swav_crop_kwargs)

medaka_dm = MedakaDataModule(
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    batch_size=batch_size,
)

In [12]:
# Log to the "SWAV" W&B project. log_model='all' uploads every checkpoint
# Lightning saves during training as a W&B artifact.
wandb_logger = WandbLogger(project="SWAV",
                           log_model='all')

# Without max_epochs (or max_steps) Lightning defaults to 1000 epochs. That runs
# far past the horizon SwAV's warmup + cosine LR schedule was built for
# (model.hparams.max_epochs), after which the cosine factor climbs again.
trainer = pl.Trainer(
    precision=16,
    accelerator='auto',
    logger=wandb_logger,
    max_epochs=model.hparams.max_epochs,
)

# fit
try:
    trainer.fit(model, medaka_dm)
finally:
    # In a notebook the W&B run otherwise stays open and later runs append to it.
    wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113479029801157, max=1.0…

  rank_zero_warn(
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  linear_warmup_decay(warmup_steps, total_steps, cosine=True),

  | Name      | Type     | Params
---------------------------------------
0 | model     | ResNet   | 28.0 M
1 | criterion | SWAVLoss | 0     
---------------------------------------
28.0 M    Trainable params
0         Non-trainable params
28.0 M    Total params
55.954    Total estimated model 

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]