<a href="https://colab.research.google.com/github/fdsig/A-Lamp/blob/master/pytorch_multipatch_ensemble.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import io
import time
import os
import subprocess
import sys
from tqdm import tqdm
import shutil
import pandas as pd
import numpy as np
import random
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import copy
import cv2
from PIL import Image
import albumentations as A
import torch

In [None]:
def run_process(process = None, command = None):
    """Run a shell command, streaming its output to stdout and to a log file.

    Args:
        process: Log-file stem; output is written to ``<process>.log``.
            Defaults to ``'process'`` when not given.
        command: Command string, executed through the shell.

    Returns:
        The command's exit code.
    """
    import codecs  # local: only this helper needs it

    logname = (process or 'process') + '.log'
    env = os.environ.copy()
    # sys.stdout is a text stream and rejects raw bytes; an incremental decoder
    # also avoids mangling a multi-byte character split across two reads.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    with io.open(logname, 'wb') as writer, io.open(logname, 'rb') as reader:
        # stderr is merged into the log so error messages (e.g. from git) show up.
        proc = subprocess.Popen(command, stdout=writer, stderr=subprocess.STDOUT,
                                shell=True, env=env)
        while proc.poll() is None:
            sys.stdout.write(decoder.decode(reader.read()))
            time.sleep(0.1)  # avoid a busy-wait spin while the command runs
        # Drain whatever was written after the last poll.
        sys.stdout.write(decoder.decode(reader.read(), final=True))
        sys.stdout.flush()
    return proc.returncode

repositories = ['https://github.com/fdsig/image_utils',
                'https://github.com/fdsig/A-Lamp']

for repo in repositories:
    run_process(command='git clone '+repo, process=repo.split('/')[-1])

# Hoist the helper modules from image_utils into the working directory so
# that `import image_getter` below resolves. Hidden entries (e.g. .git) are
# skipped so the working directory does not turn into a clone of image_utils.
with os.scandir('image_utils/') as entries:
    other_fids = [fid for fid in entries if not fid.name.startswith('.')]
for fid in other_fids:
    shutil.move(fid.path, fid.name)

# Remove leftover clone directories. The previous `rm a && rm b && rm -rf ...`
# chain aborted at the first failing `rm` (missing path, or `rm` on a
# directory), so the `rm -rf` cleanups never actually ran.
for leftover in ('convit', 'MPADA', 'image_utils'):
    if os.path.isdir(leftover) and not os.path.islink(leftover):
        shutil.rmtree(leftover, ignore_errors=True)
    elif os.path.lexists(leftover):
        os.remove(leftover)

import image_getter
pull = image_getter.Get_Ava()
pull.parse_urls()
pull.google_getter()
pull.ava_txt()
pull.download_ava_files(own_drive=True, 
                        download=True, 
                        full=True, 
                        clear_current=False)


In [13]:
def meta_process(csv_path='ava_meta_with_int_id_230721.csv', n_std=4,
                 low=4.99, high=5.01):
    """Load AVA metadata and drop outlier and near-neutral scores.

    Rows whose ``mos_float`` lies more than ``n_std`` population standard
    deviations from the mean are discarded. So are rows whose score falls in
    the band ``[low, high]``. Three progress counts are printed along the way:
    the total row count, the count after the lower cut, and the count after
    both cuts (printed twice).

    Args:
        csv_path: Metadata CSV with at least ``ID`` and ``mos_float`` columns.
        n_std: Outlier cut-off, in standard deviations from the mean.
        low: Scores strictly below this are kept (low class).
        high: Scores strictly above this are kept (high class).

    Returns:
        The rows of the metadata DataFrame whose ``ID`` survived filtering,
        in their original order.
    """
    df = pd.read_csv(csv_path)
    y_gt = df['mos_float'].values
    ids = df['ID'].values
    print(len(ids))

    # Both bounds come from the statistics of the *unfiltered* scores.
    mean, std = np.mean(y_gt), np.std(y_gt)
    above_floor = y_gt >= mean - n_std * std
    print(int(above_floor.sum()))
    in_range = above_floor & (y_gt <= mean + n_std * std)
    n_kept = int(in_range.sum())
    print(n_kept, n_kept)

    unambiguous = in_range & ((y_gt < low) | (y_gt > high))
    to_include = ids[unambiguous].astype(int)
    return df[df['ID'].isin(to_include)]

y_df = meta_process()
# Key the filtered metadata rows by the string form of their ID so they can
# be matched against image file stems ('<ID>.<ext>').
y_g = y_df.to_dict('index')
y_g_dict = {str(row['ID']): row for row in y_g.values()}
# Use os.scandir as a context manager so the directory handle is closed.
with os.scandir('Images/images') as entries:
    fids = {entry.name.split('.')[0]: {'fid': entry.path} for entry in entries}
# Keep only images that have metadata, merging the file path in under 'fid'.
y_g_dict = {key: {**y_g_dict[key], **fids[key]} for key in fids if key in y_g_dict}

class Square:
    """Pad a PIL image with black (fill 0) borders so it becomes exactly square.

    The shorter side is padded symmetrically; any odd leftover pixel goes on
    the right/bottom. An image of size ``(w, h)`` therefore becomes
    ``(max(w, h), max(w, h))``, and its aspect ratio survives a later square
    ``Resize``.
    """

    @staticmethod
    def _padding(size):
        """Return the ``(left, top, right, bottom)`` padding that squares ``size``.

        ``size`` is ``(width, height)``, the order used by ``PIL.Image.size``.
        """
        width, height = (int(v) for v in size)
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        return (left, top, side - width - left, side - height - top)

    def __call__(self, img):
        pad = self._padding(img.size)
        return torchvision.transforms.functional.pad(img, pad, 0, 'constant')


# Resize so the longest side is 224 px, then pad up to 224 x 224.
# NOTE(review): passed to ava_data, but ava_data does not apply it.
a_transform = A.Compose([
    # Public top-level aliases are stable across albumentations versions,
    # unlike the internal `A.augmentations.transforms.*` module path.
    A.LongestMaxSize(max_size=224),
    # p=1.0 (the default) replaces the deprecated always_apply=True.
    A.PadIfNeeded(min_height=224, min_width=224, p=1.0),
])
import torchvision
from torchvision import datasets, transforms
from PIL import ImageFile
from torch.utils.data import DataLoader, Dataset
ImageFile.LOAD_TRUNCATED_IMAGES = True

class ava_data(Dataset):
    """Image dataset over an AVA metadata dictionary with integer labels.

    Args:
        im_dict: Mapping of image key -> metadata dict. Each value must provide
            ``'fid'`` (path to the image file) and ``'threshold'`` (the class
            label; it is cast to ``int``).
        state: Stored as-is; not used by this class.
        transform: Callable applied to the RGB ``PIL.Image``. When ``None``,
            the PIL image itself is returned.
        a_transform: Stored as-is; not applied (the albumentations call in
            ``__getitem__`` was disabled).
    """

    def __init__(self, im_dict, state = None, transform=None, a_transform=None):
        self.im_dict = im_dict
        self.transform = transform
        self.a_transform = a_transform
        self.files = list(im_dict.keys())
        self.state = state

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _to_rgb(img):
        """Return an H x W x 3 RGB array from a cv2 image (BGR or single channel)."""
        if img.ndim != 3:
            # Grayscale: replicate the single channel three times.
            return np.stack([img] * 3, axis=2)
        # cv2 decodes colour images as BGR; reverse to the RGB order that
        # PIL and the ImageNet normalisation statistics expect.
        return np.ascontiguousarray(img[:, :, 2::-1])

    def __getitem__(self, idx):
        record = self.im_dict[self.files[idx]]
        path = record['fid']
        # Read with OpenCV rather than PIL: it is faster, and PIL raises read
        # errors on some AVA files. cv2 signals failure by returning None.
        img = cv2.imread(path)
        if img is None:
            raise OSError(f'cv2 could not read image: {path}')
        # An H x W x 3 uint8 array is interpreted as mode 'RGB' by PIL.
        img = Image.fromarray(self._to_rgb(img).astype('uint8'))
        if self.transform is not None:
            img = self.transform(img)
        label = int(record['threshold'])
        return img, label

batch_size = 10
lr = 0.0001
gamma = 0.1
seed = 42
sets = ['test', 'training', 'validation']


def _partition_by_set(records, set_names):
    """Group ``records`` (key -> metadata dict) by their ``'set'`` field.

    Returns one dict per name in ``set_names``, in that order. Each dict keeps
    the records' original insertion order; records whose set is not listed
    are dropped.
    """
    grouped = {name: {} for name in set_names}
    for key, record in records.items():
        bucket = grouped.get(record['set'])
        if bucket is not None:
            bucket[key] = record
    return grouped


splits = _partition_by_set(y_g_dict, sets)
n_train, n_test, n_val = (len(splits[name]) for name in ('training', 'test', 'validation'))
print(f"train set n = {n_train} \ntest_list n = {n_test}\nvalidation_list n = {n_val}")
import torchvision
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: square-pad, resize, then mild photometric augmentation
# (each applied independently with probability 0.25) before normalisation.
train_transforms = transforms.Compose([
    Square(),
    transforms.Resize((224, 224)),
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.25),
    transforms.RandomAutocontrast(p=0.25),
    transforms.RandomEqualize(p=0.25),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])


def _deterministic_transforms():
    """Build the augmentation-free pipeline used for validation and test."""
    return transforms.Compose([
        Square(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


val_transforms = _deterministic_transforms()
test_transforms = _deterministic_transforms()
# Each split gets its own pipeline: random augmentation belongs only in
# training; validation and test must be deterministic.
split_transforms = {
    'training': train_transforms,
    'validation': val_transforms,
    'test': test_transforms,
}
data_splits = {
    set_: ava_data(splits[set_], transform=split_transforms[set_], a_transform=a_transform)
    for set_ in splits
}

# Inverse-frequency sample weights so the sampler draws the classes roughly
# equally. Weights follow the training dict's key order, which is the same
# order ava_data indexes its files in.
labels = [splits['training'][idx]['threshold'] for idx in splits['training']]
class_counts = np.bincount(labels)
num_samples = sum(class_counts)
class_weights = [num_samples / count for count in class_counts]
weights = [class_weights[label] for label in labels]
sampler = torch.utils.data.WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))
train_loader = DataLoader(
    dataset = data_splits['training'], batch_size=batch_size, sampler=sampler,
    shuffle=False)

valid_loader = DataLoader(
    dataset = data_splits['validation'], batch_size=batch_size, shuffle=True)
test_loader = DataLoader(
    dataset = data_splits['test'], batch_size=batch_size, shuffle=True)

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Note: setting ``PYTHONHASHSEED`` here only affects processes started
    afterwards; the running interpreter's hash randomisation is fixed at
    startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the CUDA generators
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN choose algorithms by timing them, which can
    # differ between runs even when deterministic=True.
    torch.backends.cudnn.benchmark = False

seed_everything(seed)
# Fall back to CPU so the notebook still runs without a GPU runtime.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloaders = {'training': train_loader, 'validation': valid_loader}
dataset_sizes = {x: len(data_splits[x]) for x in ['training', 'validation']}



255502
255411
255403 255403
train set n = 221649 
test_list n = 19711
validation_list n = 11670


In [32]:
def _balanced_accuracy(y_true, y_pred):
    """Return the mean per-class recall over the classes present in ``y_true``.

    Matches ``sklearn.metrics.balanced_accuracy_score`` with default
    arguments, implemented in NumPy to avoid an extra dependency. Returns
    ``0.0`` when ``y_true`` is empty.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.size == 0:
        return 0.0
    recalls = [np.mean(y_pred[y_true == cls] == cls) for cls in np.unique(y_true)]
    return float(np.mean(recalls))


def train_model(model, criterion, optimizer, scheduler,
                num_epochs=None,
                model_name = None, did = None):
    """Train ``model`` and checkpoint the weights with the best validation score.

    Each epoch runs a ``'training'`` phase (with backprop and one
    ``scheduler.step()``) followed by a ``'validation'`` phase. Batches come
    from the module-level ``dataloaders`` dict and are moved to the
    module-level ``device``.

    Per-phase loss, accuracy and balanced accuracy are computed over the whole
    phase. They are collected in a results dict that is rewritten to
    ``<did>/<model_name>.json`` after every phase. Whenever the validation
    balanced accuracy improves, a dict with ``'epoch'``,
    ``'model_state_dict'`` and ``'optimizer_state_dict'`` is saved to
    ``<did>/<model_name>``.

    Args:
        model: Module called with five input tensors, returning class logits.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training phase.
        num_epochs: Number of epochs to run.
        model_name: File name used for the checkpoint and the JSON results.
        did: Output directory.

    Returns:
        ``model``, with the best-validation weights loaded.
    """
    import json  # not imported at file level

    save_path = os.path.join(did, model_name)
    results = {}

    print(f'currently trianing {model_name}')
    print(f'{model_name} will be saved at {save_path}')
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['training', 'validation']:
            is_training = phase == 'training'
            model.train(is_training)  # train(False) == eval()

            running_loss = 0.0
            running_corrects = 0
            n_seen = 0
            phase_preds, phase_labels = [], []

            for inputs, labels in tqdm(dataloaders[phase], colour=('#FF69B4')):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                # Track gradient history only while training.
                with torch.set_grad_enabled(is_training):
                    # NOTE(review): every branch receives the same full image;
                    # no per-branch patches are extracted here.
                    outputs = model(inputs, inputs, inputs, inputs, inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if is_training:
                        loss.backward()
                        optimizer.step()

                batch_n = inputs.size(0)
                running_loss += loss.item() * batch_n
                running_corrects += int((preds == labels).sum().item())
                n_seen += batch_n
                # Collect every batch so balanced accuracy covers the whole
                # phase, not just the final batch.
                phase_preds.append(preds.detach().cpu())
                phase_labels.append(labels.detach().cpu())

            if is_training:
                scheduler.step()

            # Normalise by the samples actually seen; guard an empty loader.
            epoch_loss = running_loss / max(n_seen, 1)
            epoch_acc = running_corrects / max(n_seen, 1)
            if phase_labels:
                bal_acc = _balanced_accuracy(torch.cat(phase_labels).numpy(),
                                             torch.cat(phase_preds).numpy())
            else:
                bal_acc = 0.0

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            key = 'epoch_' + str(epoch + 1) + '_' + phase
            results[key] = {
                phase + ' loss': epoch_loss,
                phase + ' acc': float(epoch_acc),
                phase + ' ballance_acc': bal_acc,
            }
            print(results)
            with open(save_path + '.json', 'w') as handle:
                json.dump(results, handle)

            # Keep (and checkpoint) the weights with the best validation score.
            if phase == 'validation' and bal_acc > best_acc:
                best_acc = bal_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save({'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()
                            }, save_path)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model


In [49]:
import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
class Multi_patch(nn.Module):
    """Ensemble of five branch modules fused by a linear classifier.

    Each branch processes its own input tensor. The branch outputs are
    concatenated along dim 1, passed through ReLU, and mapped to 2 logits by
    ``nn.Linear(10, 2)``. That layer size assumes each branch emits 2
    features — TODO confirm for any new branch type.
    """

    # Attribute names are fixed so that state_dict keys stay compatible with
    # previously saved checkpoints.
    _BRANCHES = ('_model', 'm_odel', 'mo_del', 'mod_el', 'mode_l')

    def __init__(self, _model, m_odel, mo_del, mod_el, mode_l):
        super().__init__()
        for attr, branch in zip(self._BRANCHES, (_model, m_odel, mo_del, mod_el, mode_l)):
            setattr(self, attr, branch)
        self.classifier = nn.Linear(10, 2)

    def forward(self, x1, x2, x3, x4, x5):
        branch_outputs = [
            getattr(self, attr)(x)
            for attr, x in zip(self._BRANCHES, (x1, x2, x3, x4, x5))
        ]
        fused = torch.cat(branch_outputs, dim=1)
        return self.classifier(F.relu(fused))

# Build five ImageNet-pretrained ResNet-18 branches, each with a fresh 2-way
# head, and combine them in the Multi_patch ensemble. The list is named
# `branches` so it does not shadow the `torchvision.models` module
# (re-running this cell would otherwise fail on `models.resnet18`).
branches = [models.resnet18(pretrained=True) for _ in range(5)]
for branch in branches:
    branch.fc = nn.Linear(branch.fc.in_features, 2)

model = Multi_patch(*branches).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiply the LR by `gamma` every 7 scheduler steps (one step per epoch in
# train_model). NOTE: with num_epochs=5 below, the LR is never decayed.
scheduler = StepLR(optimizer, step_size=7, gamma=gamma)
criterion = nn.CrossEntropyLoss()
train_model(model, criterion, optimizer, scheduler, 
            num_epochs=5, model_name='ensemble',
            did ='/content/drive/MyDrive/0.AVA/results/')

currently trianing ensemble
ensemble will be saved at /content/drive/MyDrive/0.AVA/results/ensemble
Epoch 0/4
----------


  0%|[38;2;255;105;180m          [0m| 17/22165 [00:04<1:34:48,  3.89it/s]


KeyboardInterrupt: ignored