In [1]:
import numpy as np
from pathlib import Path
import os
import torch
from torch import Tensor
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.nn import BCEWithLogitsLoss
import torch.nn.functional as F
from tqdm import tqdm
import json
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from torch.autograd import Variable
from preprocessing import *
from model import *
import time
import wandb 

In [2]:
class Average(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def find_latest_model(dir):
    model_paths, epochs = [], []
    for path in Path(dir).glob('*.pt'):
        if 'epoch' not in path.stem:
            continue
        model_paths.append(path)
        parts = path.stem.split('_')
        epochs.append(int(parts[-1]))

    if len(epochs) > 0:
        epochs = np.array(epochs)
        max_idx = np.argmax(epochs)
        return model_paths[max_idx]
    else:
        return None

In [3]:
class MetricsMethod(object):
    def __init__(self) -> None:
        pass
    def PA(self):
        pass
    def CPA(self):
        pass
    def MPA(self):
        pass
    def Dice(self):
        pass
    def IOU(self):
        pass
    def MIOU(self):
        pass
    def mAP(self):
        pass

In [4]:
class DiceLoss(object):
    def __init__(self, multiclass: bool = False, reduce_batch_first=True) -> None:
        self.fn = self.multiclass_dice_coeff if multiclass else self.dice_coeff
        self.reduce_batch_first = reduce_batch_first

    def loss(self, input: Tensor, target: Tensor) -> Tensor:
        return 1 - self.fn(input, target, reduce_batch_first=self.reduce_batch_first)

    def dice_coeff(self, input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
        # Average of Dice coefficient for all batches, or for a single mask
        assert input.size() == target.size()
        assert input.dim() == 3 or not reduce_batch_first

        sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

        inter = 2 * (input * target).sum(dim=sum_dim)
        sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
        print(f"{inter=}")
        print(f"{sets_sum=}")
        sets_sum = torch.where(sets_sum == 0, inter, sets_sum)
        print(f"{sets_sum=}\n")

        dice = (inter + epsilon) / (sets_sum + epsilon)
        return dice.mean()

    def multiclass_dice_coeff(self, input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
        # Average of Dice coefficient for all classes
        return self.dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)

In [5]:
def GraphVisualization(dataset: MyDataset, model=None, col=5, target_dir: str="./"):
    rows = ['Images', 'Ground\nTruth\nMasks', 'Ground\nTruth\nFusions',
            'Prediction\nMasks', 'Prediction\nFusions', 'Prediction V.S.\nGround Truth']
    if model is None:
        fig, axes = plt.subplots(nrows=3, ncols=col, figsize=(10,10))

        for i in range(3):
            axes[i][0].annotate(rows[i], xy=(0, 0.5), xytext=(-30,60),
                                xycoords='axes points', textcoords='offset points',
                                size='large', ha='center', va='center')

        for i in range(5):
            filename = Path(dataset.imgPaths[i]).stem
            data = dataset.__getitem__(i)
            mask = np.array(data[1]).squeeze()
            invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
                                                                 std = [ 1/0.229, 1/0.224, 1/0.225 ]),
                                            transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
                                                                 std = [ 1., 1., 1. ])])
            img = invTrans(data[0])
            img = np.array(img).transpose(1,2,0)

            axes[0][i].set_title(filename, {'fontsize': 8})
            axes[0][i].get_xaxis().set_visible(False)
            axes[0][i].get_yaxis().set_visible(False)
            axes[0][i].imshow(img)
            axes[1][i].get_xaxis().set_visible(False)
            axes[1][i].get_yaxis().set_visible(False)
            axes[1][i].imshow(mask, cmap='magma')
            axes[2][i].get_xaxis().set_visible(False)
            axes[2][i].get_yaxis().set_visible(False)
            axes[2][i].imshow(img)
            axes[2][i].imshow(mask, cmap='twilight', alpha=0.6)

        fig.tight_layout(h_pad=-25)
        plt.savefig(os.path.join(target_dir, 'sample.png'), dpi=500)
        plt.show()
    else:
        pass

In [6]:
def validate(model, val_loader, criterion):
    losses = Average()
    model.eval()
    with torch.no_grad():

        for i, (input, target) in enumerate(val_loader):
            input_img  = Variable(input_img).to(device)
            masks_true = Variable(masks_true).to(device)

            masks_pred = model(input_img)
            output = torch.argmax(masks_pred, dim=1)
            # masks_pred = F.softmax(masks_pred, dim=1).float(),
            # masks_true = F.one_hot(masks_true, 8).squeeze(1).permute(0, 3, 1, 2).float()
            loss = criterion(output, masks_true)

            losses.update(loss.item(), masks_true.size(0))

    return losses.avg

In [7]:
def train(config, train_loader, model, criterion, optimizer, scheduler, validation):
    wandb.watch(model, criterion=criterion, log="all", log_freq=10)

    latest_model_path = find_latest_model(model_dir)
    best_model_path = os.path.join(*[model_dir, 'model_best.pt'])

    if latest_model_path is not None:
        state = torch.load(latest_model_path)
        epoch = state['epoch']
        model.load_state_dict(state['model'])
        epoch = epoch

        assert Path(best_model_path).exists() == True, f'best model path {best_model_path} does not exist'
        best_state = torch.load(latest_model_path)
        min_val_los = best_state['valid_loss']

        print(f'Restored model at epoch {epoch}. Min validation loss : {min_val_los}')
        epoch += 1
        print(f'Started training model from epoch {epoch}')
    else:
        print('Started training model from epoch 0')
        epoch = 0
        min_val_los = 9999

    valid_losses = []
    start_time = time.time()
    for epoch in range(epoch, config['n_epoch'] + 1):

        tq = tqdm(total=(len(train_loader) * config['batch_size']))
        tq.set_description(f'Epoch {epoch}')

        running_losses = Average()

        model.train()
        for i, (input_img, masks_true) in enumerate(train_loader):
            input_img  = Variable(input_img, requires_grad=True).to(device, dtype=torch.float32)
            masks_true = Variable(masks_true, requires_grad=True).to(device, dtype=torch.long)

            masks_pred = model(input_img)
            # masks_pred = F.softmax(masks_pred, dim=1).float()
            # masks_true = F.one_hot(masks_true.squeeze(1), 8).permute(0, 3, 1, 2).float()
            masks_true = F.one_hot(masks_true.squeeze(1), 8).permute(0, 3, 1, 2).float()
            print("pred shape: ", masks_pred.size())
            print("true shape: ", masks_true.size())
            loss = criterion(masks_pred, masks_true)
            running_losses.update(loss)

            tq.set_postfix(loss='{:.5f}'.format(running_losses.avg))
            tq.update(config['batch_size'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()
        valid_loss = validation(model, valid_loader, criterion)
        valid_losses.append(valid_loss)
        print(f'valid_loss = {valid_loss:.5f}')
        tq.close()
        
        wandb.log({"training_loss": running_losses.avg})
        wandb.log({"valid_loss": valid_loss})
        
        epoch_model_path = os.path.join(*[model_dir, f'model_epoch_{epoch}.pt'])
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'valid_loss': valid_loss,
            'train_loss': running_losses.avg
        }, epoch_model_path)

        if valid_loss < min_val_los:
            min_val_los = valid_loss
            torch.save({
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_loss': valid_loss,
                'train_loss': running_losses.avg
            }, best_model_path)
    
    finished_training_time = (time.time()-start_time)
    print(f"training_time(s): ", finished_training_time)
    wandb.log({"training_time(s)": finished_training_time})

In [8]:
config = dict(
    n_epoch         = 50,
    batch_size      = 1,
    lr              = 0.001,
    num_workers     = 0,
    momentum        = 0.9,
    weight_decay    = 1e-4,
    dataset         = "CCAgT",
    model           = "ResUNet",
    optimizer       = "Adam",
    scheduler       = "StepLR"
)

wandb_init = dict(
#     job_type: Optional[str] = None,
#     dir = None,
    config = config,
    project = "ResUNet with CCAgT dataset",
#     entity = None,
#     reinit = None,
    # tags = ['wgan_gp_1v1'],
#     group = None,
    name = None,
    notes = None,
#     magic = None,
#     config_exclude_keys = None,
#     config_include_keys = None,
#     anonymous = None,
    mode = "online",  # "online","offline","disabled"
#     allow_val_change = None,
#     resume = None,
#     force = None,
#     tensorboard = None,
#     sync_tensorboard = None,
#     monitor_gym = None,
    save_code = True,
#     settings=None
)

model_dir = "./model"
orig_img_dir = "./dataset/images"
orig_msk_dir = "./dataset/masks"
save_json = "./dataset/dataset.json"
save_samples = "./dataset/samples"

if save_samples != '':
    os.makedirs(save_samples, exist_ok=True)
    for sample in Path(save_samples).glob('*.jpg'):
        os.remove(str(sample))

if not os.path.exists(save_json):
    obtain_path(img_dir=orig_img_dir, mask_dir=orig_msk_dir, target_path=str(save_json))

os.makedirs(model_dir, exist_ok=True)

In [9]:
gpu_id = 0

torch.manual_seed(42)
torch.backends.cudnn.deterministic = True

torch.cuda.set_device(gpu_id)
# device = torch.device("cuda:{}".format(str(gpu_id)) if torch.cuda.is_available() else "cpu")
device = "cpu"
# print("device", torch.cuda.current_device(), torch.cuda.get_device_name(torch.cuda.current_device()))

In [10]:
channel_means = [0.485, 0.456, 0.406]
channel_stds  = [0.229, 0.224, 0.225]

dataset = divide_dataset(save_json, [0.7,0.1,0.2])

train_tsfm = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize(channel_means, channel_stds),
                                transforms.RandomCrop((832, 832)), 
                                transforms.RandomRotation(90), 
                                transforms.RandomHorizontalFlip()])
val_tsfm = transforms.Compose([ transforms.ToTensor(), 
                                transforms.Normalize(channel_means, channel_stds),
                                transforms.RandomCrop((832, 832))])
test_tsfm = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize(channel_means, channel_stds),
                                transforms.RandomCrop((832, 832))])
train_mask_tsfm = transforms.Compose([transforms.ToTensor(),
                                      transforms.RandomCrop((832, 832)),
                                      transforms.RandomRotation(90),
                                      transforms.RandomHorizontalFlip()])
mask_tsfm = transforms.Compose([transforms.ToTensor(),
                                transforms.RandomCrop((832, 832)), 
                                transforms.RandomRotation(90)])

train_set = MyDataset(dataset, train_tsfm, train_mask_tsfm, 'train')
valid_set = MyDataset(dataset, val_tsfm, mask_tsfm, 'valid')
test_set = MyDataset(dataset, test_tsfm, mask_tsfm, 'test')

train_loader = DataLoader(  train_set,
                            config['batch_size'],
                            shuffle=True,
                            pin_memory=torch.cuda.is_available(),
                            num_workers=config['num_workers'])
valid_loader = DataLoader(  valid_set,
                            config['batch_size'],
                            shuffle=True,
                            pin_memory=torch.cuda.is_available(),
                            num_workers=config['num_workers'])
test_loader = DataLoader(   test_set,
                            config['batch_size'],
                            shuffle=True,
                            pin_memory=torch.cuda.is_available(),
                            num_workers=config['num_workers'])

# GraphVisualization(test_set, model=None, col=5, target_dir=save_samples)

In [11]:
wandb.login()
os.environ["WANDB_API_KEY"] = "a9932db05eeba1bfd135b895b1e586738f267083"
os.environ['WANDB_NOTEBOOK_NAME'] = 'main.ipynb'

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mohmygoose0410[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
model = ResUNet(Block=ResBlock, DecBlock=DecBlock)
optimizer = torch.optim.Adam(model.parameters(), config['lr'],
                             weight_decay=config['weight_decay'])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 30, 0.1)
criterion = BCEWithLogitsLoss().to(device)

model.to(device)

run = wandb.init(**wandb_init)

train(config, train_loader, model, criterion, optimizer, scheduler, validate)

run.finish()

Started training model from epoch 0


Epoch 0:   0%|          | 0/6537 [00:00<?, ?it/s]

pred shape:  torch.Size([1, 8, 832, 832])
true shape:  torch.Size([1, 1, 832, 832])


ValueError: Target size (torch.Size([1, 1, 832, 832])) must be the same as input size (torch.Size([1, 8, 832, 832]))