In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import faiss
import time
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
import cv2
from attention_unet import Attention_block, conv_block, up_conv
import matplotlib.pyplot as plt
from dataloaders import ShapeDataset, Birds_Dataset
import torchvision.utils

In [2]:
# Root folder of the dataset; the cells below read train/ and test/ sub-folders from it.
# Set the CUB_DATASET_DIR environment variable to point the notebook at another
# location without editing code; the fallback is the original hard-coded path.
# dataset_path = Path('/home/e_radionova/Datasets/SimpleShapes_10_classes/dataset/')
dataset_path = Path(os.environ.get(
    'CUB_DATASET_DIR',
    '/home/e_radionova/Datasets/Caltech_birds/CUB_200_2011/dataset',
))

In [3]:
# One preprocessing pipeline is shared by images and masks: resize to 400x400
# with nearest-neighbour interpolation, then convert to a tensor.
# NOTE(review): NEAREST is also applied to the input images, not only to the
# masks -- confirm that this is intended.
trans = transforms.Compose([
    transforms.Resize((400, 400), interpolation=Image.NEAREST),
#     transforms.CenterCrop((400, 400)),
    transforms.ToTensor(),
])

# The training set is built from the 'train' sub-folder and the validation set
# from the 'test' sub-folder of dataset_path.
train_set, val_set = (
    Birds_Dataset(
        images_folder=dataset_path / split_dir / 'images',
        masks_folder=dataset_path / split_dir / 'masks',
        img_transform=trans,
        masks_transform=trans,
    )
    for split_dir in ('train', 'test')
)

In [4]:
# Batch sizes: validation uses the same batch size as training.
batch_size_train = 4
batch_size_val = batch_size_train

# Datasets keyed by phase name; the same keys are used for the loaders below.
image_datasets = {'train': train_set, 'val': val_set}

# Both loaders shuffle and load in the main process (num_workers=0).
dataloaders = {
    'train': DataLoader(
        train_set,
        batch_size=batch_size_train,
        shuffle=True,
        num_workers=0,
    ),
    'val': DataLoader(
        val_set,
        batch_size=batch_size_val,
        shuffle=True,
        num_workers=0,
    ),
}

# Number of samples per phase.
dataset_sizes = {
    phase_name: len(phase_set) for phase_name, phase_set in image_datasets.items()
}

dataset_sizes

{'train': 5994, 'val': 5794}

In [5]:
# SEED = 42
# random.seed(SEED)
# torch.manual_seed(SEED)
# torch.cuda.manual_seed_all(SEED)
# np.random.seed(SEED)

# n_pics_to_show = batch_size_train
# inputs, masks = next(iter(dataloaders['train']))
# fig, ax = plt.subplots(n_pics_to_show, 1, figsize=(7, 15))
# for i in range(n_pics_to_show):
#     pic, label = inputs[i], masks[i] 
#     label_np = label.data.numpy().transpose(1, 2, 0) 
#     pic_np = pic.data.numpy().transpose(1, 2, 0) 
#     ax[i,0].imshow(pic_np)
#     ax[i,1].imshow(label_np)

In [6]:
from torchsummary import summary
import torch
import torch.nn as nn
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss, calc_loss, print_metrics
# from attention_maps_funcs import get_activation, plot_attention_map
from tqdm.notebook import tqdm

In [7]:
# Folder (relative to the working directory) that holds one sub-folder per
# TensorBoard run; created on first use, left untouched if it already exists.
logs_base_dir = Path('.') / 'logs'
logs_base_dir.mkdir(exist_ok=True)

In [8]:
# Registry filled by the hooks built below: name -> detached output.
activation = {}


def get_activation(name):
    """Build a hook that records a module's output under ``name``.

    The returned callable takes ``(module, inputs, output)`` positionally and,
    on every call, overwrites ``activation[name]`` with ``output.detach()``.
    """
    def hook(module, inputs, output):
        activation.update({name: output.detach()})
    return hook

def plot_attention_map(attention_map, N=10, n_columns=5):
    """Show a composite view of an attention map, then its first ``N`` channels on a grid.

    Args:
        attention_map: tensor indexed channel-first; it is transposed with axes
            (1, 2, 0) for the composite view and indexed as ``map[i]`` per channel.
        N: number of leading channels to draw on the grid.
        n_columns: number of grid columns; the row count is ceil(N / n_columns).

    Returns:
        None; the figures are shown with matplotlib.
    """
    # Normalize with mean 0 / std 1 -- presumably leaves the values unchanged;
    # TODO confirm whether a real normalisation was intended.
    normed_att_map = transforms.Normalize(0, 1)(attention_map).detach().cpu().numpy()
    # NOTE(review): all channels are handed to imshow at once here; matplotlib
    # renders only (H, W), (H, W, 3) or (H, W, 4) arrays, so this presumably
    # fails for other channel counts -- confirm this preview is still wanted.
    plt.imshow(normed_att_map.transpose(1, 2, 0))

    n_rows = (N + n_columns - 1) // n_columns
    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so the [row, column] indexing below is always valid.
    fig, axes = plt.subplots(n_rows, n_columns, figsize=(15, 6), squeeze=False)
    for map_i in range(N):
        row_index, column_index = divmod(map_i, n_columns)
        axes[row_index, column_index].imshow(normed_att_map[map_i])
    plt.show()

In [9]:
from template_matching_funcs import thresholding, get_Fourier_coeffs_and_kernel

def get_feature_map(attention_map, index=9, threshold=True):
    """Plot an attention map and, optionally, a post-processed mask derived from it.

    Args:
        attention_map: tensor; it is detached, moved to the CPU, converted to
            NumPy and transposed with axes (1, 2, 0) before plotting.
        index: currently unused (the per-channel selection is commented out).
        threshold: when true, the map is additionally passed through the
            project helper ``thresholding`` (argument 1.0), a 5x5 morphological
            closing and one dilation, rescaled so its maximum is 255, and shown.

    Returns:
        None; the function only plots and prints the unique values of the map.
    """
    chosen_map = attention_map.detach().cpu().numpy().transpose(1, 2, 0)  # [index]
    plt.imshow(chosen_map)
    plt.title('Choosen map')
    plt.show()
    print('Unique els: ', np.unique(chosen_map))

    if not threshold:
        return

    mask = thresholding(chosen_map, 1.0)
    # The same 5x5 rectangular structuring element serves closing and dilation.
    struct_elem = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    mask = cv2.dilate(
        cv2.morphologyEx(mask, cv2.MORPH_CLOSE, struct_elem),
        struct_elem,
        iterations=1,
    )
    mask = 255 * abs(mask / mask.max())
    plt.imshow(mask)
    plt.show()

In [10]:
def get_normalize_image(image, limit=1.):
    """Linearly rescale ``image`` so its minimum maps to 0 and its maximum to ``limit``.

    Args:
        image: array of values to rescale.
        limit: value the maximum of the image is mapped to.

    Returns:
        The rescaled image. A constant image (max == min) is returned as all
        zeros instead of NaN/inf produced by a division by zero.
    """
    shifted = image - np.min(image)
    peak = shifted.max()
    # For a constant image every shifted value is 0; dividing by 1 keeps the
    # result at 0 (and keeps the usual output dtype) instead of yielding 0/0.
    if peak == 0:
        peak = 1
    return limit * (shifted / peak)

In [11]:
def get_kernel_baseline(image, template, order, kernel_size,
                       morph_open, morph_close, dilate):
    """Build a kernel from a template after optional morphological clean-up.

    Args:
        image: currently unused (the template-matching call is commented out).
        template: input handed to the project helper ``thresholding`` (argument 1.0).
        order: forwarded to ``get_Fourier_coeffs_and_kernel``.
        kernel_size: forwarded to ``get_Fourier_coeffs_and_kernel``.
        morph_open: apply a 5x5 morphological opening when true.
        morph_close: apply a 5x5 morphological closing when true.
        dilate: apply one 5x5 dilation when true.

    Returns:
        Tuple ``(kernel, mask)``: the second value returned by
        ``get_Fourier_coeffs_and_kernel`` and the processed mask it was built from.
    """
    mask = thresholding(template, 1.0)

    # 5x5 rectangular structuring element shared by every morphological step.
    struct_elem = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))

    # Opening and closing run in this fixed order, each only when enabled.
    for enabled, morph_op in ((morph_open, cv2.MORPH_OPEN),
                              (morph_close, cv2.MORPH_CLOSE)):
        if enabled:
            mask = cv2.morphologyEx(mask, morph_op, struct_elem)
    if dilate:
        mask = cv2.dilate(mask, struct_elem, iterations=1)

    _coeffs, kernel = get_Fourier_coeffs_and_kernel(
        mask, order=order, kernel_size=kernel_size
    )
    return kernel, mask

In [12]:
def plot_attention_map(attention_map, N=30, n_columns=5):
    """Draw the first ``N`` channels of an attention map on a grid of subplots.

    Args:
        attention_map: tensor indexed channel-first (each ``map[i]`` is drawn
            as one image).
        N: number of leading channels to draw.
        n_columns: number of grid columns; the row count is ceil(N / n_columns).

    Returns:
        None; the figure is shown with matplotlib.
    """
    # Normalize with mean 0 / std 1 -- presumably leaves the values unchanged;
    # TODO confirm whether a real normalisation was intended.
    normed_att_map = transforms.Normalize(0, 1)(attention_map)
    channel_maps = normed_att_map.detach().cpu().numpy()

    n_rows = (N + n_columns - 1) // n_columns
    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so the [row, column] indexing below is always valid.
    fig, axes = plt.subplots(n_rows, n_columns, figsize=(40, 20), squeeze=False)
    for map_i in range(N):
        row_index, column_index = divmod(map_i, n_columns)
        axes[row_index, column_index].imshow(channel_maps[map_i])
    plt.show()

In [13]:
import torchvision.transforms.functional as F
# from torchvision.transforms import InterpolationMode

In [14]:
def kernel_torch(att_map, kernel_size):
    """Zero out the low values of an attention map and resize the result.

    The cut-off is the mean of the map's unique values; entries that are not
    greater than it are replaced by 0 (``torch.nn.Threshold``), the others are
    kept. The thresholded map is then resized with torchvision's ``resize``.

    Args:
        att_map: tensor to threshold and resize.
        kernel_size: target size forwarded to ``resize``.

    Returns:
        The thresholded, resized tensor.
    """
    # Imported locally because the notebook binds the alias `F` to two different
    # modules (torch.nn.functional, then torchvision.transforms.functional);
    # `resize` exists only in the latter, so the result must not depend on
    # which cell ran last.
    import torchvision.transforms.functional as TF

    # Threshold documents a float threshold, so pass a plain Python number.
    cutoff = torch.mean(torch.unique(att_map)).item()
    thresh = torch.nn.Threshold(cutoff, 0., inplace=False)(att_map)
    return TF.resize(thresh, kernel_size)

In [15]:

def train_model(model, optimizer, scheduler, experiment_name, num_epochs=25):
    """Train and validate ``model``; return it with the best-validation-loss weights.

    Reads the notebook globals ``dataloaders`` (keys 'train' and 'val'),
    ``device`` and ``logs_base_dir``. Per-epoch loss and DICE of each phase are
    written to TensorBoard under ``logs_base_dir / experiment_name``.

    Args:
        model: network to optimise; called as ``model(inputs)``.
        optimizer: optimizer updating the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after the
            training phase.
        experiment_name: name of the TensorBoard run sub-folder.
        num_epochs: number of train + validation epochs.

    Returns:
        ``model`` after loading the state dict of the epoch with the lowest
        validation loss.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    best_dice = 0

    writer = SummaryWriter(logs_base_dir / experiment_name)

    try:
        for epoch in tqdm(range(num_epochs)):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            since = time.time()

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    # Report the learning rate actually used during this epoch.
                    for param_group in optimizer.param_groups:
                        print("LR", param_group['lr'])

                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                # Running sums filled in by calc_loss; the 'loss' and 'dice'
                # entries are read below (presumably per-sample sums, since they
                # are divided by the sample count -- verify against loss.py).
                metrics = defaultdict(float)
                epoch_samples = 0

                for inputs, labels in tqdm(dataloaders[phase]):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward; track history only in the training phase
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        loss = calc_loss(outputs, labels, metrics)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    epoch_samples += inputs.size(0)

                if phase == 'train':
                    # Step the schedule only after the optimizer updates of this
                    # epoch; stepping before the first optimizer.step() shifts
                    # the whole schedule by one epoch.
                    scheduler.step()

                print_metrics(metrics, epoch_samples, phase)
                epoch_loss = metrics['loss'] / epoch_samples
                writer.add_scalar(f'Loss_{phase}', epoch_loss.item(), global_step=epoch)

                dice_epoch = metrics['dice'] / epoch_samples
                writer.add_scalar(f'DICE_{phase}', dice_epoch.item(), global_step=epoch)

                # deep copy the model
                if phase == 'val' and epoch_loss < best_loss:
                    print("saving best loss")
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

                # Only the best DICE value is tracked; no weights are stored for it.
                if phase == 'val' and dice_epoch > best_dice:
                    print("saving best DICE")
                    best_dice = dice_epoch

            time_elapsed = time.time() - since
            print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    finally:
        # Flush and release the TensorBoard event file even if training is
        # interrupted.
        writer.close()

    print('Best val loss: {:4f}'.format(best_loss))
    print('Best val DICE: {:4f}'.format(best_dice))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [18]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda:0


In [20]:
# Alternative model from the attention_unet module, currently disabled:
# import attention_unet
# model = attention_unet.AttU_Net(3, 3).to(device)
# Active model: AttU_Net from the project module attention_kernel2conv1, moved to
# the selected device.
# NOTE(review): the two positional arguments (3, 3) are presumably the input and
# output channel counts -- confirm against attention_kernel2conv1.AttU_Net.
import attention_kernel2conv1
model = attention_kernel2conv1.AttU_Net(3, 3).to(device)

In [21]:
# model.Att4.register_forward_hook(get_activation('Att4'))

In [22]:
# Adam over all model parameters, starting at a learning rate of 1e-3.
optimizer_ft = optim.Adam(params=model.parameters(), lr=1e-3)

# Multiply the learning rate by 0.1 every 15 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer=optimizer_ft,
    step_size=15,
    gamma=0.1,
)

# Run name: fixed prefix plus the current local time at minute resolution.
exp_name = 'Based_attention_50epochs_b4_{}'.format(
    datetime.now().isoformat(timespec='minutes')
)

In [None]:
# Run the full training; `model` is rebound to the value train_model returns.
model = train_model(
    model,
    optimizer_ft,
    exp_lr_scheduler,
    experiment_name=exp_name,
    num_epochs=50,
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))

Epoch 0/49
----------
LR 0.001




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1499.0), HTML(value='')))




train: bce: 0.280112, dice: 0.521580, loss: 0.379266


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1449.0), HTML(value='')))


val: bce: 0.298878, dice: 0.576187, loss: 0.361345
saving best loss
saving best DICE
27m 6s
Epoch 1/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1499.0), HTML(value='')))


train: bce: 0.221436, dice: 0.635679, loss: 0.292878


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1449.0), HTML(value='')))


val: bce: 0.221745, dice: 0.667982, loss: 0.276881
saving best loss
saving best DICE
27m 17s
Epoch 2/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1499.0), HTML(value='')))


train: bce: 0.193849, dice: 0.691479, loss: 0.251185


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1449.0), HTML(value='')))


val: bce: 0.253459, dice: 0.688098, loss: 0.282681
saving best DICE
27m 18s
Epoch 3/49
----------
LR 0.001


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1499.0), HTML(value='')))