In [1]:
import numpy as np
import time
import importlib
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
from PIL import Image
import os
from models.relaynet.relay_net import ReLayNet
from early_stop import EarlyStopping
from losses import CombinedLoss

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
def calc_dice(truth:torch.tensor,pred:torch.tensor)->float:
    """
    Calculates binary dice for each class and then returns the average dice across all classes

    Args:
        truth (torch.tensor): Ground truth
        pred (torch.tensor): Predictions from model

    Returns:
        float: Average dice score across all classes
    """
    pred_cpu = pred.cpu()
    true_cpu = truth.cpu()
    
    pred_unique = torch.unique(pred_cpu)
    true_unique = torch.unique(true_cpu)
    print(f"LENGTH:{len(pred_unique)}")
    dices = []
    pred_list = []
    true_list = []

    # convert the classes into 0,1,2,3,4,.....
    for i in range(len(pred_unique)):
        # print(cls)
        pred_cpu[pred_cpu==pred_unique[i]] = i
    
    for i in range(len(true_unique)):
        # print(cls)
        true_cpu[true_cpu==true_unique[i]] = i

    # Converts each class(0,1,2,3....) into boolean arrays and append them to a list
    # E.g
    # [0,0,0]    [1,1,1] [0,0,0] [0,0,0]
    # [1,1,1] => [0,0,0],[1,1,1],[0,0,0]
    # [2,2,2]    [0,0,0] [0,0,0] [1,1,1]
    # The resulting list will contain all the classes, with their index in the list corresponding to their class
    # E.g index[0] => class label 0
    for i in range(len(pred_unique)):
        pred_bool = torch.where(pred_cpu == i,True,False)
        true_bool = torch.where(true_cpu == i,True,False)
        pred_list.append(pred_bool)
        true_list.append(true_bool)

    for i in range(len(pred_unique)):
        intersection = torch.logical_and(pred_list[i],true_list[i])
        dice = 2 * intersection.sum()/(pred_list[i].sum() + true_list[i].sum())
        # print(dice)
        dices.append(dice)
    return np.average(dices)

In [4]:
def train_model(epochs:int,model,train_loader,val_loader,criterion:CombinedLoss,optimizer:torch.optim,patience:int,checkpoint_path:str):
    """
    Trains the model

    Args:
        epochs (int): num of epochs to run training for
        model (_type_): Instance of model to train
        train_loader (torch.utils.data.Dataloader): Pytorch Dataloader for training data
        val_loader (torch.utils.data.Dataloader): Pytorch Dataloader for validation data
        criterion (CombinedLoss): Loss function
        optimizer (torch.optim): Optimizer
        patience (int): num of epochs for early stopping. Early stopping will trigger if there are #patience of epochs w/o improvement
        checkpoint_path (str): Filepath to save model weights
    """
    print("Training started!")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # start_time = time.time()
    train_losses = []
    train_dices = []
    valid_losses = []
    val_dices = []
    avg_train_losses = []
    avg_train_dice = []
    avg_valid_losses = []
    avg_val_dice = []
    running_train_loss = 0.0
    running_valid_loss = 0.0
    early_stopping = EarlyStopping(patience=patience, verbose=True,path=checkpoint_path)
    for e in range(1,epochs+1):
        print(f"Epoch {e}/{epochs}")
        ###################
        # train the model #                    
        ###################
        model.train()
        for i,data in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            img,mask,weights = data
#             mask = torch.permute(mask,(0,3,1,2))

            img = img.to(device)
            mask = mask.to(device)
            weights = weights.to(device)

            outputs = model(img)
            print("ABC")
            # calc train dice
            soft = nn.functional.softmax(outputs,dim=1)
            maxxed = torch.argmax(soft,dim=1)
            # print(f"SOFT SHAPE:{soft.shape} MAXXED SHAPE:{maxxed.shape}")
            # print(f"SOFT :{soft[0]} MAXXED:{maxxed[0]}")
            maxxed2 = torch.argmax(outputs,dim=1)
            print(f"MAXXED :{maxxed[0]} MAXXED2:{maxxed2[0]}")
#             train_dice = calc_dice(maxxed,mask)
            train_dice = calc_dice(mask,maxxed)

            loss = criterion(outputs,mask,weights)
            loss.backward()
            optimizer.step()

            # running_stats
            running_train_loss += loss.item()
            train_losses.append(loss.item())
            train_dices.append(train_dice)
            if i%100 == 0:
                print(f'Training Epoch {e}, Iter {i + 1:5d} Running loss: {running_train_loss :.3f}, Dice: {train_dice}')
        ######################
        # validate the model #                    
        ######################
        model.eval()
        with torch.no_grad():
            for j,data in enumerate(tqdm(val_loader)):
                img,mask,weights = data
                img = img.to(device)
                mask = mask.to(device)
    #             mask = torch.permute(mask,(0,3,1,2))
                weights = weights.to(device)
                outputs = model(img)
                loss = criterion(outputs,mask,weights)
                # calc dice loss
                soft = nn.functional.softmax(outputs,dim=1)
                print(f"SOFT SHAPE:{soft.shape} SOFT:{soft}")

#                 maxxed = torch.argmax(outputs,dim=1)
#                 val_dice = calc_dice(maxxed,mask)
                val_dice = calc_dice(mask,soft)
                # running_stats
                running_valid_loss += loss.item()
                valid_losses.append(loss.item())
                val_dices.append(val_dice)
                if j%100 == 0:
                    print(f'Validation Epoch {e}, Iter {j + 1:5d} Running loss: {running_valid_loss :.3f}, Dice: {val_dice}')
        
        epoch_train_loss = np.average(train_losses)
        epoch_train_dice = np.average(train_dices)
        epoch_valid_loss = np.average(valid_losses)
        epoch_val_dice = np.average(val_dices)
        avg_train_losses.append(epoch_train_loss)
        avg_train_dice.append(epoch_train_dice)
        avg_valid_losses.append(epoch_valid_loss)
        avg_val_dice.append(epoch_val_dice)

        epoch_len = len(str(epochs+1))
        
        print_msg = (f'[{e:>{epoch_len}}/{epochs+1:>{epoch_len}}] ' +
                     f'avg_train_loss: {epoch_train_loss:.5f} ' +
                     f'avg_train_dice: {epoch_train_dice:.5f} ' +
                     f'avg_valid_loss: {epoch_valid_loss:.5f} ' +
                     f'avg_valid_dice: {epoch_val_dice:.5f}')
        
        print(print_msg)

        running_train_loss = 0.0
        running_valid_loss = 0.0
        train_losses = []
        valid_losses = []
        train_dices = []
        val_dices = []

        early_stopping((epoch_val_dice*-1),model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
        
    model.load_state_dict(torch.load(f'{checkpoint_path}'))

    return  model, avg_train_losses, avg_valid_losses

In [5]:
class MsdDataset(Dataset):
    def __init__(self,root_dir,mode,w1,w2,transform=None) -> None:
        super().__init__()
        # data
        self.data_dir = root_dir + f"/data/{mode}/"
        self.data_files = sorted(os.listdir(self.data_dir))
        # labels
        self.display_dir = root_dir + f"/display/{mode}/"
        self.display_files = sorted(os.listdir(self.display_dir))

        self.seg_dir = root_dir + f"/segmentations/{mode}/"
        self.seg_files = sorted(os.listdir(self.seg_dir))
        
        self.weight_dir = root_dir + f"/weights/{mode}/"
        self.weight_files = sorted(os.listdir(self.weight_dir))
        
        self.transform = transform
        self.w1 = w1
        self.w2 = w2

    def __len__(self):
        return len(os.listdir(self.data_dir))

    def calc_weights(self,w1,w2,label):
        """
        W1: Weighting for pixels that are proximal to tissue-transition regions. E.g pixels near the boundary/edges between to different segments
            Equates one if gradient between 2 pixels is more than 1 -> If the pixel (x) is besides
        W2: Equals one if the class labFw2el belongs to an under-represented class
        
        """
        # raw_tensor = torch.from_numpy(raw) 
        label_tensor = torch.from_numpy(label)
        
        # shape is (H,W)
        # print(label_tensor.shape)
        # Calculating the weights for W1

        # Initialise w1 weight map with all zeroes first
        w1_map = torch.zeros(label_tensor.shape)
        # print(f"Initialised w1_map \n {w1_map.shape}")
        # print("Calculating w1 map...")

        num_rows = label_tensor.shape[0]
        for row in range(1,num_rows):
            # We use row and row-1 so that we won't get an index out of bounds error while
            # iterating
            first_row = label_tensor[row-1,:]
            second_row = label_tensor[row,:]
            prev_a = None
            prev_b = None

            # iterate through each column in each rows
            # print(len(first_row))
            for col in range(len(first_row)):
                a = first_row[col]
                b = second_row[col]
                if a != b:
                    # There exists a boundary between a and b, so we should weigh these pixels
                    w1_map[row-1,col] = 1
                    w1_map[row,col] = 1
                else:
                    # if we are not at the first(leftmost) col, we check if pixels side by side are the same
                    # If they are not the same, there exists a boundary between a/b and prev_a/b so we should weigh
                    # these pixels
                    if (col != 0) and (prev_a is not None) and (prev_b is not None):
                        if a != prev_a:
                            w1_map[row-1,col] = 1
                            w1_map[row-1,col-1] = 1
                        if b != prev_b:
                            w1_map[row,col] = 1
                            w1_map[row,col-1] = 1
                    elif (col != 0 ) and (prev_a is None) and (prev_b is None):
                        raise Exception(f"Something went wrong, we were at row {row} col {col} and prev_a or prev_b is NOT none. prev_a:{prev_a}, prev_b:{prev_b}")
                    else:
                        # We are at the first(leftmost) col, and prev_a and prev_b is None so there is nothing to compare to
                        pass
                prev_a = a
                prev_b = b
        # End
        # print(f"Finished calculating W1 map")
        # print(w1_map)
        w1_map.float()
        # Initialise w2 weight map with all zeroes first
        w2_map = torch.zeros(label_tensor.shape)
        # class label/idx 2 is the "dominant" class so we will weigh the pixels with a class label that is != 2
        # w2_map = torch.eq(label_tensor,2).long()
        w2_map = torch.eq(label_tensor,2)
        # return 1 if value of w2_map = False, else return 0
        w2_map = torch.where(w2_map == False,1,0).float()
        # weighted_map = 1 + (w1*w1_map) + (w2*w2_map)
        # print(f"W1 : {w1}")
        # print(f"W1_map is \n {w1_map}\n")
        # print(f"W2 : {w2}")
        # print(f"W2_map is \n {w2_map}\n")
        # print(f"W1_weighted is : {(w1*w1_map)}\n")
        # print(f"W2_weighted is : {(w2*w2_map)}\n")
        # print(f"W1_weighted + W2_weighted is {np.add((w1*w1_map),(w2*w2_map))}")
        # print(f"w1 type is : {w1_map.type()}")
        # print(f"w2 type is : {w2_map.type()}")
        # print(f"one_map shape is {one_map.shape}")
        w1_weighted_map = w1 * w1_map
        w2_weighted_map = w2 * w2_map
        one_map = torch.ones(w1_map.shape)
        w1w2_map = torch.add(w1_weighted_map,w2_weighted_map)
        weighted_map = torch.add(one_map,w1w2_map)
        return weighted_map
        

    def __getitem__(self, index):
        # return super().__getitem__(index)
        # load images
        file_path = os.path.join(self.data_dir,self.data_files[index])
        display_path = os.path.join(self.display_dir,self.display_files[index])
        seg_path = os.path.join(self.seg_dir,self.seg_files[index])
        weight_path = os.path.join(self.weight_dir,self.weight_files[index])
#         print(f"FETCHING {file_path}")
        file = Image.open(file_path)
        display = Image.open(display_path)
        seg = Image.open(seg_path)
        # test_mask = seg.convert("RGB")
        if self.transform is not None:
            file = self.transform(file)
            display = self.transform(display)
            # seg = self.transform(seg)
        # seg_array = np.array(seg)
        # test_mask_arr = np.array(test_mask)
        seg = torch.from_numpy(np.array(seg))
        
        

        # test_mask_arr.shape
        # print(seg)
#         weights = self.calc_weights(self.w1,self.w2,seg_array)
        weights = np.load(weight_path)
        return file,seg,weights

In [6]:
class MsdTestDataset(Dataset):
    def __init__(self,root_dir,mode,w1,w2,transform=None) -> None:
        super().__init__()
        # data
        self.data_dir = root_dir + f"/data/{mode}/"
        self.data_files = sorted(os.listdir(self.data_dir))
        # labels
        self.display_dir = root_dir + f"/display/{mode}/"
        self.display_files = sorted(os.listdir(self.display_dir))

        self.seg_dir = root_dir + f"/segmentations/{mode}/"
        self.seg_files = sorted(os.listdir(self.seg_dir))
        
        self.weight_dir = root_dir + f"/weights/{mode}/"
        self.weight_files = sorted(os.listdir(self.weight_dir))
        
        self.transform = transform
        self.w1 = w1
        self.w2 = w2

    def __len__(self):
        return len(os.listdir(self.data_dir))

    def calc_weights(self,w1,w2,label):
        """
        W1: Weighting for pixels that are proximal to tissue-transition regions. E.g pixels near the boundary/edges between to different segments
            Equates one if gradient between 2 pixels is more than 1 -> If the pixel (x) is besides
        W2: Equals one if the class labFw2el belongs to an under-represented class
        
        """
        # raw_tensor = torch.from_numpy(raw) 
        label_tensor = torch.from_numpy(label)
        
        # shape is (H,W)
        # print(label_tensor.shape)
        # Calculating the weights for W1

        # Initialise w1 weight map with all zeroes first
        w1_map = torch.zeros(label_tensor.shape)
        # print(f"Initialised w1_map \n {w1_map.shape}")
        # print("Calculating w1 map...")

        num_rows = label_tensor.shape[0]
        for row in range(1,num_rows):
            # We use row and row-1 so that we won't get an index out of bounds error while
            # iterating
            first_row = label_tensor[row-1,:]
            second_row = label_tensor[row,:]
            prev_a = None
            prev_b = None

            # iterate through each column in each rows
            # print(len(first_row))
            for col in range(len(first_row)):
                a = first_row[col]
                b = second_row[col]
                if a != b:
                    # There exists a boundary between a and b, so we should weigh these pixels
                    w1_map[row-1,col] = 1
                    w1_map[row,col] = 1
                else:
                    # if we are not at the first(leftmost) col, we check if pixels side by side are the same
                    # If they are not the same, there exists a boundary between a/b and prev_a/b so we should weigh
                    # these pixels
                    if (col != 0) and (prev_a is not None) and (prev_b is not None):
                        if a != prev_a:
                            w1_map[row-1,col] = 1
                            w1_map[row-1,col-1] = 1
                        if b != prev_b:
                            w1_map[row,col] = 1
                            w1_map[row,col-1] = 1
                    elif (col != 0 ) and (prev_a is None) and (prev_b is None):
                        raise Exception(f"Something went wrong, we were at row {row} col {col} and prev_a or prev_b is NOT none. prev_a:{prev_a}, prev_b:{prev_b}")
                    else:
                        # We are at the first(leftmost) col, and prev_a and prev_b is None so there is nothing to compare to
                        pass
                prev_a = a
                prev_b = b
        # End
        # print(f"Finished calculating W1 map")
        # print(w1_map)
        w1_map.float()
        # Initialise w2 weight map with all zeroes first
        w2_map = torch.zeros(label_tensor.shape)
        # class label/idx 2 is the "dominant" class so we will weigh the pixels with a class label that is != 2
        # w2_map = torch.eq(label_tensor,2).long()
        w2_map = torch.eq(label_tensor,2)
        # return 1 if value of w2_map = False, else return 0
        w2_map = torch.where(w2_map == False,1,0).float()
        # weighted_map = 1 + (w1*w1_map) + (w2*w2_map)
        # print(f"W1 : {w1}")
        # print(f"W1_map is \n {w1_map}\n")
        # print(f"W2 : {w2}")
        # print(f"W2_map is \n {w2_map}\n")
        # print(f"W1_weighted is : {(w1*w1_map)}\n")
        # print(f"W2_weighted is : {(w2*w2_map)}\n")
        # print(f"W1_weighted + W2_weighted is {np.add((w1*w1_map),(w2*w2_map))}")
        # print(f"w1 type is : {w1_map.type()}")
        # print(f"w2 type is : {w2_map.type()}")
        # print(f"one_map shape is {one_map.shape}")
        w1_weighted_map = w1 * w1_map
        w2_weighted_map = w2 * w2_map
        one_map = torch.ones(w1_map.shape)
        w1w2_map = torch.add(w1_weighted_map,w2_weighted_map)
        weighted_map = torch.add(one_map,w1w2_map)
        return weighted_map
        

    def __getitem__(self, index):
        # return super().__getitem__(index)
        # load images
        file_path = os.path.join(self.data_dir,self.data_files[index])
        display_path = os.path.join(self.display_dir,self.display_files[index])
        seg_path = os.path.join(self.seg_dir,self.seg_files[index])
        weight_path = os.path.join(self.weight_dir,self.weight_files[index])
        pred_name = self.seg_files[index]
        # print(f"FETCHING {self.seg_files[index]}")
        file = Image.open(file_path)
        display = Image.open(display_path)
        seg = Image.open(seg_path)
        if self.transform is not None:
            file = self.transform(file)
            display = self.transform(display)
            # seg = self.transform(seg)
        seg_array = np.array(seg)
        seg = torch.from_numpy(np.array(seg))
        # print(seg)
#         weights = self.calc_weights(self.w1,self.w2,seg_array)
        weights = np.load(weight_path)
        return file,seg,weights,pred_name

In [7]:
model_params = {
        'num_channels': 1,
        'num_filters': 64,
        'kernel_h': 7,
        'kernel_w': 3,
        'kernel_c': 1,
        'stride_conv': 1,
        'pool': 2,
        'stride_pool': 2,
        'num_class': 3,
        'epochs': 6
    }
    
print("Instantiating model...")
model = ReLayNet(model_params)
moedl = model.to(device)
print("Relaynet instantiated!")
print(model)

print("Setting up criterion")
# criterion = nn.CrossEntropyLoss()
criterion = CombinedLoss()
print("Setting up criterion")

print("Initializing criterion...")
optimizer = torch.optim.SGD(model.parameters(),lr=0.1,momentum=0.9,weight_decay=0.0001)
print("Criterion initialized!")

print("Initializing scheduler...")
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.1)
print("Scheduler initialized!")

print("Initializing train loader...")
transform = {
    # "train": transforms.Compose([transforms.Grayscale(),transforms.ToTensor()]),
    # "val": transforms.Compose([transforms.Grayscale(),transforms.ToTensor()]),
    # "test": transforms.Compose([transforms.Grayscale(),transforms.ToTensor()])
    "train": transforms.Compose([transforms.ToTensor()]),
    "val": transforms.Compose([transforms.ToTensor()]),
    "test": transforms.Compose([transforms.ToTensor()])

}
train_dataset = MsdDataset(root_dir="data",mode="train",transform=transform["train"],w1=10,w2=5)
train_loader = DataLoader(train_dataset,batch_size=1,shuffle=True)
print("Train loader initialized!")

print("Initializing val loader...")
val_dataset = MsdDataset(root_dir="data",mode="val",transform=transform["val"],w1=10,w2=5)
val_loader = DataLoader(val_dataset,batch_size=1,shuffle=True)
print("Val loader initialized!")

print("Initializing test loader...")
test_dataset = MsdTestDataset(root_dir="data",mode="test",transform=transform["test"],w1=10,w2=5)
test_loader = DataLoader(test_dataset,batch_size=1,shuffle=True)
print("Test loader initialized!")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Starting training...")
trained_model = train_model(epochs=6, model=model, train_loader=train_loader, val_loader=val_loader, criterion=criterion, optimizer=optimizer, patience=2,checkpoint_path="checkpoint_1.pt")
print("Finished training!")
# test_model = test_model(model=model,state_dict_path="checkpoint.pt",test_loader=test_loader,output_path="./preds/")



Instantiating model...
Relaynet instantiated!
ReLayNet(
  (encode1): EncoderBlock(
    (conv): Conv2d(1, 64, kernel_size=(7, 3), stride=(1, 1), padding=(3, 1))
    (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encode2): EncoderBlock(
    (conv): Conv2d(64, 64, kernel_size=(7, 3), stride=(1, 1), padding=(3, 1))
    (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encode3): EncoderBlock(
    (conv): Conv2d(64, 64, kernel_size=(7, 3), stride=(1, 1), padding=(3, 1))
    (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )




  0%|          | 0/3600 [00:00<?, ?it/s]

ABC
MAXXED :tensor([[2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        ...,
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2]], device='cuda:0') MAXXED2:tensor([[2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        ...,
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2]], device='cuda:0')
LENGTH:3
Training Epoch 1, Iter     1 Running loss: 3.235, Dice: 0.4077662527561188
ABC
MAXXED :tensor([[2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        ...,
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2]], device='cuda:0') MAXXED2:tensor([[2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ..., 2, 2, 2],
        ...,
        [2, 2, 2,  ..., 2, 2, 2],
        [2, 2, 2,  ...

  0%|          | 0/1200 [00:00<?, ?it/s]

SOFT SHAPE:torch.Size([1, 3, 496, 768]) SOFT:tensor([[[[9.8592e-01, 9.9756e-01, 9.9615e-01,  ..., 9.9338e-01,
           9.8625e-01, 8.2703e-01],
          [9.9400e-01, 9.9805e-01, 9.9728e-01,  ..., 9.9707e-01,
           9.9359e-01, 8.9325e-01],
          [9.9680e-01, 9.9835e-01, 9.9860e-01,  ..., 9.9654e-01,
           9.9487e-01, 9.5309e-01],
          ...,
          [6.1109e-04, 1.0896e-04, 3.3417e-05,  ..., 1.0044e-05,
           4.5357e-05, 3.4149e-04],
          [1.3952e-03, 1.8573e-04, 4.9738e-05,  ..., 1.8283e-05,
           6.4466e-05, 5.6985e-04],
          [1.4922e-03, 2.0946e-04, 7.7062e-05,  ..., 3.1776e-05,
           1.1083e-04, 9.9457e-04]],

         [[1.2041e-03, 2.5365e-04, 3.5041e-04,  ..., 7.0032e-04,
           1.0281e-03, 3.3960e-03],
          [1.0587e-03, 2.6505e-04, 3.4383e-04,  ..., 6.2226e-04,
           9.7722e-04, 4.6117e-03],
          [1.1081e-03, 4.2410e-04, 4.0003e-04,  ..., 1.2204e-03,
           1.1904e-03, 4.6868e-03],
          ...,
          [1.8

In [None]:
def test_model(model,state_dict_path,test_loader,output_path):
    
    pred_list = []
    model.load_state_dict(torch.load(state_dict_path))
    model.to(device)
    model.eval()
    count = 0
    with torch.no_grad():
        for i,data in enumerate(tqdm(test_loader)):
            img,mask,weights,file_names = data
            img = img.to(device)
            mask = mask.to(device)
            weights = weights.to(device)
            
#             print(f"INPUT SHAPE:{img.shape}")
#             print(f"MASK SHAPE:{mask.shape}")
#             print(f"WEIGHTS SHAPE:{weights.shape}")
            pred = model(img)
            # cpu_pred = pred.clone().detach().cpu()
            # print(f"pred shape : {pred.shape}")
            # print(f"pred unique vals:{len(np.unique(cpu_pred))}")
            # print(f"pred unique vals[0]:{np.unique(cpu_pred[0][0][0][0])}")
            # print(f"pred unique vals[1]:{np.unique(cpu_pred[0][1][0][0])}")
            # print(f"pred unique vals[2]:{np.unique(cpu_pred[0][2][0][0])}")
            # print(f"pred unique vals[0] shape:{np.unique(cpu_pred[0][0][0].shape)}")
            maxxed_pred = torch.argmax(pred,dim=1)
            print(maxxed_pred)
            # print(f"maxxed_pred:{maxxed_pred}")
            # print(f"maxxed_pred shape:{maxxed_pred.shape}")
            # break
#             for each in pred:
#                 print(each.shape)
            file_pred = zip(mask,maxxed_pred,file_names)
            for truth,pred,name in file_pred:

                # print(f"pred shape is :{pred.shape}")
                # print(f"pred type is :{pred.type()}")
                pil = transforms.ToPILImage()
                truth_pil = pil(truth)
                img_pil = pil(pred.type(torch.uint8))
                # print(img_pil.size)
                truth_pil.save(f"{output_path}/truths/{name}")
                img_pil.save(f"{output_path}/preds/{name}")
                count+=1

            

In [None]:
from PIL import Image

In [None]:
test_mask = Image.open("./data_snakemake/segmentations/test/1257012_1705_2570_15354_RNHP_0001.png")
test_mask = test_mask.convert("RGB")
test_mask_arr = np.array(test_mask)
test_mask_arr.shape


In [None]:
results = test_model(model=model,state_dict_path="checkpoint_1.pt",test_loader=test_loader,output_path="./results/")

In [None]:
def calc_dice(truth_dir,pred_dir):
    """
    
    """
    truths = sorted(os.listdir(truth_dir))
    preds = sorted(os.listdir(pred_dir))
    
    truth_pred = zip(truths,preds)
    total_dice = []
    for t,p in truth_pred:
        # print(t)
        # print(p)
        # print("\n")
        t_img = Image.open(f"{truth_dir}/{t}")
        p_img = Image.open(f"{pred_dir}/{p}")
        t_arr = torch.from_numpy(np.array(t_img))
        p_arr = torch.from_numpy(np.array(p_img))
        
        # Transforming both truth n pred into boolean arrays
        t_bool = torch.eq(t_arr,p_arr)
        p_bool = torch.eq(p_arr,p_arr)
        t_b_arr = np.array(t_bool)
        p_b_arr = np.array(p_bool)
        # This gives true positives
        intersection = np.logical_and(t_arr, p_arr) #where A and B are numpy boolean arrays
        # intersection = torch.eq(t_arr, 
        # dice = 2 * intersection.sum()/(t_arr.sum() + p_arr.sum())
        # t_b_arr.sum() & p_b_arr.sum() gives us (TP + FP + FN) because True is evaluated as 1 and False is 0
        # Therefore TP=>both true(1), FP(A is 0 if B is 1), FN(Ais 1 when B is 0) and we are essentially just summing up all the 1s.
        dice = 2 * intersection.sum()/(t_b_arr.sum() + p_b_arr.sum())
        total_dice.append(dice)
        # break
    return np.mean(np.array(total_dice))


dice_metric = calc_dice("./results/truths/","./results/preds")

In [None]:
dice_metric
# b1 = np.array([True,True,True])
# b2 = np.array([True,True,True])
# b_result = np.logical_and(b1,b2)
# print(f"b_result:{b_result}")

# e1 = torch.from_numpy(np.array([1,1,3]))
# e2 = torch.from_numpy(np.array([1,2,3]))
# e_result = torch.eq(e1,e2)
# print(f"e_result:{e_result}")

In [None]:
# img,mask,weights = test_dataset.__getitem__(0)
# print(img.shape)
# print(mask.shape)
# print(weights.shape)

# test = Image.open("1043.png")
# test_arr = np.array(test)
# # print(np.unique(test_arr))
# print(test_arr.shape)
# test_arr[0][0] 

In [None]:
# import matplotlib.pyplot as plt
# m = Image.open("data/masks/train/1161012_1704_2567_15337_RNHP_0001.png")
# upsized = m.resize((768,496))
# upsized.size

In [None]:
# a = Image.open("preds/0.png")
# a_arr = np.array(a)
# print(np.unique(a_arr))

In [None]:
torch.eq(torch.tensor(0.49),torch.tensor(0.49))
a = np.logical_and([1,1],[0,0])
b = np.logical_and([0,0],[1,1])
print(a)
print(b)