In [1]:
import numpy as np
import matplotlib.pyplot as plt

SEED = 42
np.random.seed(SEED)
data = np.load("combined_patches.npy",allow_pickle=True)
data = np.array(data,dtype=np.float32)
data = np.random.permutation(data)
print("Shape of data: ", data.shape)
#print(np.unique(data))

#plt.imshow(data[0][1],cmap="gray",vmin=0,vmax=1)
#plt.colorbar()
#plt.show()

Shape of data:  (1000, 2, 512, 512)


In [2]:
train_size, val_size = 800, 150
train_data, val_data, test_data = np.split(data, [train_size, train_size+val_size])
print("Shape of train_data: ", train_data.shape)
print("Shape of val_data: ", val_data.shape)
print("Shape of test_data: ", test_data.shape)

Shape of train_data:  (800, 2, 512, 512)
Shape of val_data:  (150, 2, 512, 512)
Shape of test_data:  (50, 2, 512, 512)


In [3]:
from dataset import SheetletCellDataset
from torchvision import transforms
from matplotlib import pyplot as plt

transform = transforms.Compose([
        #transforms.Pad(padding=20,padding_mode="reflect"),
        transforms.Resize((256,256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        #transforms.RandomRotation(15,fill=0),
        transforms.GaussianBlur(kernel_size=(11,11),sigma=(0.1,0.7)),
        #transforms.RandomAffine(0, translate=(0.05,0.05), scale=(0.9,1.15), shear=2),    
    ])

train_dataset = SheetletCellDataset(train_data, transform=transform)
val_dataset = SheetletCellDataset(val_data)
test_dataset = SheetletCellDataset(test_data)

for i in range(len(train_dataset)):
    print(train_dataset[i][0].shape,train_dataset[i][1].shape)
    plt.imshow(train_dataset[i][0][0],cmap="gray",vmin=0,vmax=1)
    plt.imshow(train_dataset[i][1][0],cmap="gray",vmin=0,vmax=1,alpha=0.5)
    plt.colorbar()
    plt.show()
    break

torch.Size([512, 512]) torch.Size([512, 512])


ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [512] and output size of [256, 256]. Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.

In [4]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=12, shuffle=True, num_workers=4,pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=12, shuffle=False, num_workers=4,pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=12, shuffle=False, num_workers=4,pin_memory=True)

#get first batch
batch = next(iter(train_dataloader))
print(batch[0].shape,batch[1].shape)


torch.Size([12, 1, 256, 256]) torch.Size([12, 1, 256, 256])


In [5]:
from model import GANModel
from utils import load_config

config = load_config("config.yaml")
model = GANModel(config, 1e-5, 1e-5)


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
model

GANModel(
  (generator): Generator(
    (encoder): Encoder(
      (encoder_blocks): ModuleList(
        (0): Sequential(
          (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (1): LeakyReLU(negative_slope=0.2)
        )
        (1): Sequential(
          (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.2)
        )
        (2): Sequential(
          (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.2)
        )
        (3): Sequential(
          (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReL

In [7]:
from pytorch_lightning.loggers import WandbLogger  # Change thisfrom 
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
import torch
wandb_logger = WandbLogger(
        project="sheetlet-gan",
        name="gan-training",
        log_model=True,
        config=config)
    
checkpoint_callback = ModelCheckpoint(
        monitor="g_val_loss",
        dirpath="checkpoints",
        filename="gan-{epoch:02d}-{val_g_loss:.2f}",
        save_top_k=3,
        mode="min")

In [8]:
trainer = pl.Trainer(
        max_epochs=500,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=[1],
        logger=wandb_logger,  # Use wandb logger
        callbacks=[checkpoint_callback],
        log_every_n_steps=10,
    )

 

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
trainer.fit(model, train_dataloader, val_dataloader)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mignasialemanyjuv[0m ([33mignasialemanyjuv-imperial-college-london[0m). Use [1m`wandb login --relogin`[0m to force relogin


/home/ignasi/anaconda3/envs/pytorch_venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/ignasi/Desktop/StyleGAN/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name             | Type              | Params | Mode 
---------------------------------------------------------------
0 | generator        | Generator         | 37.6 M | train
1 | discriminator    | Discriminator     | 2.8 M  | train
2 | adversarial_loss | BCEWithLogitsLoss | 0      | train
3 | l1_loss          | L1Loss            | 0      | train
---------------------------------------------------------------
40.4 M    Trainable params
0         Non-trainable params
40.4 M    Total params
161.592   Total estimated model params size (MB)
79        Modules in train mode
0         Modules in eval mode


Epoch 0:  19%|█▉        | 13/67 [00:03<00:15,  3.54it/s, v_num=rr28, g_loss_l1=45.60, g_loss_adv=0.808, g_loss=46.40, d_loss_real=0.588, d_loss_fake=0.798, d_loss=0.693]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined