In [1]:
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import wandb
import multiprocessing


import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import segmentation_models_pytorch as smp

from models.unet import *
from models.simple_model import *
from models.tri_unet import *
from models.divergent_nets import *
from utils.data_utils.acdc_datamodule import *
from utils.data_utils.data_utils import *
from utils.model_utils.dice_score import *
from utils.model_utils.resnet_loss import ResnetLoss 

from lightning.pytorch.callbacks import RichProgressBar



In [2]:
# Seed globally before building the datamodule and models below so augmentation and weight init are reproducible.
pl.seed_everything(42)

Seed set to 42


42

In [3]:
# Constants and hyperparameters
NUM_CLASSES = 4  # number of segmentation classes (the model below emits NUM_CLASSES channels per head)
MAX_EPOCHS = 500

# The big model takes a lot of memory, so only small batch sizes fit
BATCH_SIZE_TRAIN = 4
BATCH_SIZE_VAL = 4
BATCH_SIZE_TEST  = 4



In [4]:
# NOTE(review): DualTransform's positional args (20, 0.2, 0.2) are opaque from here — presumably an
# angle plus two fractional ranges; confirm against utils.data_utils and prefer keyword arguments.
transform = DualTransform(20,0.2,0.2)
# Data is read from the "database" path. (256,256,1) is presumably the (H, W, C) sample shape —
# consistent with the 256x256 model output further down. TODO confirm.
datamodule = ACDCDataModule("database", BATCH_SIZE_TRAIN,BATCH_SIZE_VAL,BATCH_SIZE_TEST,(256,256,1), convert_to_single=False, transform=transform)
# Prepare the "fit" stage so train_dataloader() can be used below.
datamodule.setup("fit")


In [5]:
class DoubleConv(nn.Module):
    """Two stacked (grouped 3x3 conv -> BatchNorm -> ReLU) stages.

    Supports several independent "heads" packed along the channel axis: every
    channel count is multiplied by ``nb_heads`` and both convolutions use
    ``groups=nb_heads``, so each head owns its own filters and only sees its
    own slice of the input channels.

    Args:
        in_channels: Input channels per head.
        out_channels: Output channels per head.
        nb_heads: Number of heads packed side by side along the channel axis.
        mid_channels: Channels per head after the first conv; when falsy
            (None or 0) ``out_channels`` is used instead.
    """

    def __init__(self, in_channels, out_channels, nb_heads=1, mid_channels=None):
        super().__init__()

        per_head_mid = mid_channels if mid_channels else out_channels
        # Scale every per-head width by the number of heads.
        total_in, total_mid, total_out = (
            width * nb_heads for width in (in_channels, per_head_mid, out_channels)
        )

        def conv_bn_relu(c_in, c_out):
            # bias=False: the following BatchNorm supplies the affine shift.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False, groups=nb_heads),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Layer order (and hence state_dict indices 0..5) is conv, bn, relu, conv, bn, relu.
        self.double_conv = nn.Sequential(
            *conv_bn_relu(total_in, total_mid),
            *conv_bn_relu(total_mid, total_out),
        )

    def forward(self, x):
        """Apply both stages.

        ``x`` must carry ``in_channels * nb_heads`` channels (e.g. (N, C, H, W));
        the 3x3 / padding=1 convs keep H and W unchanged.
        """
        features = self.double_conv(x)
        return features

In [6]:
# NOTE(review): fcn_resnet50 / FCN_ResNet50_Weights are not used in this cell; move to the
# imports cell or drop them.
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

# Unet
# model = smp.Unet('resnet34',classes=NUM_CLASSES, in_channels=1)

# Multi-head UNet: 1 input channel per head, NUM_CLASSES output channels, 2 heads kept apart by
# grouped convolutions (groups=2 in the repr below; the forward pass returns NUM_CLASSES * 2 channels).
# Uses NUM_CLASSES instead of a hard-coded 4 so the class count is configured in one place.
model = UNet(1, NUM_CLASSES, n_heads=2)
# Standalone two-head DoubleConv (1 -> 64 channels per head); its layers match model.inc.
dconv = DoubleConv(1, 64, 2)


In [7]:
# Inspect the standalone two-head block: grouped convs (groups=2), 2 -> 128 -> 128 channels.
dconv

DoubleConv(
  (double_conv): Sequential(
    (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
)

In [8]:
# Pull one training batch via the public iterator protocol (not the private DataLoader._get_iterator()).
# Element 0 is presumably the image tensor of an (image, mask) pair — confirm against ACDCDataModule.
first_data = next(iter(datamodule.train_dataloader()))[0]
# Tile the single input channel once per head so each group of the first conv gets its own copy.
# The factor 2 must match n_heads / nb_heads used to build `model` and `dconv`.
first_data_rep = first_data.repeat(1,2,1,1)


In [11]:

# NOTE(review): disabled experiment that copies model.inc's weights into dconv. Restore or delete this cell.
# If revived, beware: BatchNorm `weight` is 1-D, so `.repeat(2,1,1,1)` yields a (2,1,1,C) tensor rather
# than a tiled 1-D one; and model.inc is already two-headed (same 128-channel layers as dconv), so
# repeating along dim 0 would double the output channels and no longer match dconv.
# for ix,w in enumerate(model.inc.double_conv):
#     if hasattr(w, 'weight'):
#         print(ix,w.weight)
#         dconv.double_conv[ix].weight = nn.Parameter(w.weight.repeat(2,1,1,1))
#         print(ix,w.weight.repeat(2,1,1,1).shape)

In [12]:
# NOTE(review): disabled debug cell (prints every weighted layer of dconv); delete before sharing.
# for ix,w in enumerate(dconv.double_conv):
#     if hasattr(w, 'weight'):
#         print(ix,w.weight)

In [13]:
# NOTE(review): disabled debug cell (overwrites every weight in dconv with the constant 5); delete before sharing.
# for w in dconv.double_conv:
#     if hasattr(w, 'weight'):
#         w.weight.data = w.weight.data.fill_(5)

In [14]:
# Inspect the multi-head UNet architecture (note groups=2 on every conv).
model

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(

In [15]:
# Sanity-check the forward pass: output is (batch, 8, H, W) — presumably NUM_CLASSES (4) x 2 heads.
# NOTE(review): the model is never put in eval() here and autograd is on, so this probe updates
# BatchNorm running stats; wrap in model.eval() / torch.no_grad() if that matters.
model(first_data_rep).shape

torch.Size([4, 8, 256, 256])

In [16]:
# NOTE(review): displays the whole activation tensor; `.shape` or a small slice keeps the output readable.
dconv(first_data_rep)

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.5942],
          [0.0819, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.8993],
          [0.0819, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.8993],
          ...,
          [0.0819, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.8993],
          [0.0819, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.8993],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.9868]],

         [[0.0000, 0.8330, 0.8330,  ..., 0.8330, 0.8330, 0.5377],
          [0.2119, 0.9613, 0.9613,  ..., 0.9613, 0.9613, 0.4036],
          [0.2119, 0.9613, 0.9613,  ..., 0.9613, 0.9613, 0.4036],
          ...,
          [0.2119, 0.9613, 0.9613,  ..., 0.9613, 0.9613, 0.4036],
          [0.2119, 0.9613, 0.9613,  ..., 0.9613, 0.9613, 0.4036],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.3990, 0.3990,  ..., 0.3990, 0.3990, 0.2470],
          [0.0527, 0.7853, 0.7853,  ..., 0.7853, 0.7853, 0.3261],
          [0.0527, 0.7853, 0.7853,  ..., 0