In [14]:
import torch
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import faiss
import time
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
import cv2
import matplotlib.pyplot as plt
import torchvision.utils

import sys
sys.path.insert(1, '/home/e_radionova/PROJECT/optimal-kernels/Attention_TemplateMatching_FT')
from dataloaders import Birds_Dataset

In [15]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [16]:
dataset_path = Path('/home/e_radionova/Datasets/Caltech_birds/CUB_200_2011/dataset')

In [17]:
trans = transforms.Compose([
    transforms.Resize((390, 390), interpolation=Image.BICUBIC),
#     transforms.CenterCrop((400, 400)),
    transforms.ToTensor(),
])

train_set = Birds_Dataset(images_folder = dataset_path / 'train/images', 
                         masks_folder = dataset_path / 'train/masks',
                         img_transform = trans, 
                         masks_transform = trans)

val_set = Birds_Dataset(images_folder = dataset_path / 'test/images', 
                         masks_folder = dataset_path / 'test/masks',
                         img_transform = trans, 
                         masks_transform = trans)

In [18]:
image_datasets = {
    'train': train_set , 'val': val_set
}

batch_size_train = 8
batch_size_val = batch_size_train

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size_train, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size_val, shuffle=True, num_workers=0)
}

dataset_sizes = {
    x: len(image_datasets[x]) for x in image_datasets.keys()
}

dataset_sizes

{'train': 5994, 'val': 5794}

In [19]:
from torchsummary import summary
import torch
import torch.nn as nn
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss, calc_loss, print_metrics
# from attention_maps_funcs import get_activation, plot_attention_map
from tqdm.notebook import tqdm

In [20]:
logs_base_dir = Path('./logs_newnet')
logs_base_dir.mkdir(exist_ok=True)

In [21]:
import torchvision.transforms.functional as F
# from torchvision.transforms import InterpolationMode

In [22]:

def train_model(model, optimizer, scheduler, experiment_name, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    best_dice = 0
    
    writer = SummaryWriter(logs_base_dir / experiment_name)

    for epoch in tqdm(range(num_epochs)):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        since = time.time()
        i=0
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                    
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                
                inputs = inputs.to(device)
                labels = labels.to(device)    

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)
                

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples   
            writer.add_scalar(f'Loss_{phase}', epoch_loss.item(), global_step=epoch)
            
            dice_epoch = metrics['dice'] / epoch_samples
            writer.add_scalar(f'DICE_{phase}', dice_epoch.item(), global_step=epoch)
            
            imgs_to_tb = torchvision.utils.make_grid(inputs)
            writer.add_image('images', imgs_to_tb, global_step=epoch)
            
            outs_to_tb = torchvision.utils.make_grid(outputs)
            writer.add_image('images', outs_to_tb, global_step=epoch)
            
#             kernels = model.Conv_upd_x1.conv[0].weight.data.detach().cpu()
#             kernels = kernels[:10]
#             kernels_to_tb = torchvision.utils.make_grid(kernels, nrow=5, pad_value=255)
#             writer.add_image('kernels', kernels_to_tb, global_step=epoch)

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best loss")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
            
            if phase == 'val' and dice_epoch > best_dice:
                print("saving best DICE")
                best_dice = dice_epoch

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))
    print('Best val DICE: {:4f}'.format(best_dice))
    
#     file_name = "Best_DICE_" + Model_name + '_' + model_type 
#     with open(f"{files_dir}/{file_name}.txt","a") as the_file:
#             the_file.write('Fold №{}, best DICE: {}\n'.format(fold, best_dice))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [23]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [24]:
import NewNet
model = NewNet.UNet(3).to(device)

In [25]:
optimizer_ft = optim.Adam(model.parameters(), lr=1e-3)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=15, gamma=0.1)

exp_name = 'UNet_' + datetime.now().isoformat(timespec='minutes') 

In [None]:
model = train_model(model, optimizer_ft, exp_lr_scheduler, experiment_name=exp_name, num_epochs=30) 

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch 0/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


train: bce: 0.364627, dice: 0.324083, loss: 0.520272


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=725.0), HTML(value='')))


val: bce: 0.308592, dice: 0.437196, loss: 0.435698
saving best loss
saving best DICE
9m 2s
Epoch 1/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


train: bce: 0.301022, dice: 0.477710, loss: 0.411656


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=725.0), HTML(value='')))


val: bce: 0.312786, dice: 0.549299, loss: 0.381744
saving best loss
saving best DICE
9m 5s
Epoch 2/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


train: bce: 0.255775, dice: 0.576939, loss: 0.339418


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=725.0), HTML(value='')))


val: bce: 0.232829, dice: 0.605729, loss: 0.313550
saving best loss
saving best DICE
9m 6s
Epoch 3/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


train: bce: 0.231246, dice: 0.621870, loss: 0.304688


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=725.0), HTML(value='')))


val: bce: 0.198080, dice: 0.617054, loss: 0.290513
saving best loss
saving best DICE
9m 6s
Epoch 4/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


train: bce: 0.218994, dice: 0.644962, loss: 0.287016


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=725.0), HTML(value='')))


val: bce: 0.211730, dice: 0.673167, loss: 0.269281
saving best loss
saving best DICE
9m 6s
Epoch 5/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


train: bce: 0.200738, dice: 0.679937, loss: 0.260400


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=725.0), HTML(value='')))


val: bce: 0.234078, dice: 0.668479, loss: 0.282799
9m 7s
Epoch 6/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))


train: bce: 0.189942, dice: 0.703429, loss: 0.243257


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=725.0), HTML(value='')))


val: bce: 0.180561, dice: 0.707678, loss: 0.236441
saving best loss
saving best DICE
9m 6s
Epoch 7/29
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=750.0), HTML(value='')))