In [12]:
import pytorch_lightning as pl
import torch 
from pytorch_lightning.metrics.functional.classification import accuracy
from tqdm import tqdm
from src import DataModule, Model
import torchvision
import pandas as pd 
from pathlib import Path
from skimage import io
import glob
import os

In [22]:
models=glob.glob('efficientnet*')
models

['efficientnet_b2a-256-fold_1-val_acc=0.87692.ckpt',
 'efficientnet_b2a-256-fold_2-val_acc=0.87396.ckpt',
 'efficientnet_b2a-256-fold_3-val_acc=0.87396.ckpt',
 'efficientnet_b2a-256-fold_4-val_acc=0.87671.ckpt',
 'efficientnet_b2a-256-fold_5-val_acc=0.87246.ckpt']

In [23]:
def evaluate_tta(model, dl, tta = 1, limit = 1):   
    model.eval()
    model.cuda()
    tta_preds = []
    for i in range(tta):
        preds = torch.tensor([]).cuda()
        labels = torch.tensor([]).cuda()
        with torch.no_grad():
            t = tqdm(dl)
            for b, (x, y) in enumerate(t):
                x, y = x.cuda(), y.cuda()
                labels = torch.cat([labels, y])
                y_hat = model(x)
                preds = torch.cat([preds, y_hat])
                if b >= int(limit*len(dl)): 
                    break
        tta_preds.append(preds)
    tta_preds = torch.stack(tta_preds).mean(axis=0)
    return accuracy(tta_preds, labels).item()

In [27]:
dm = DataModule(
    file = 'data_extra', 
    batch_size=256,
    val_trans={
        'CenterCrop': {
            'height': 256, 
            'width': 256
        },
        'Normalize': {}
    }
)
dm.setup()

Training samples:  21642
Validation samples:  5411


In [24]:
path = './data/test_images'
images = os.listdir(path)
images_paths = [f'{path}/{img}' for img in images]
len(images)

1

In [31]:
model = Model.load_from_checkpoint(checkpoint_path='efficientnet_b2a-256-fold_1-val_acc=0.87692.ckpt')
model.hparams

"backbone":      efficientnet_b2a
"batch_size":    128
"es_start_from": 3
"extra_data":    1
"folds":         5
"gpus":          1
"lr":            1e-05
"max_epochs":    10
"num_workers":   20
"optimizer":     Adam
"patience":      3
"pin_memory":    True
"precision":     16
"pretrained":    True
"scheduler":     {'OneCycleLR': {'max_lr': 0.005, 'total_steps': 10, 'pct_start': 0.2, 'verbose': True}}
"size":          256
"subset":        0
"train_trans":   {'RandomCrop': {'height': 256, 'width': 256}, 'HorizontalFlip': {}, 'VerticalFlip': {}, 'Normalize': {}}
"unfreeze":      0
"val_batches":   1.0
"val_trans":     {'CenterCrop': {'height': 256, 'width': 256}, 'Normalize': {}}

In [32]:
preds = evaluate_tta(model,  dm.val_dataloader(), tta=10,limit=0.33)
preds




  0%|          | 0/22 [00:00<?, ?it/s][A[A[A


  5%|▍         | 1/22 [00:02<00:48,  2.32s/it][A[A[A


  9%|▉         | 2/22 [00:04<00:46,  2.34s/it][A[A[A


 14%|█▎        | 3/22 [00:07<00:44,  2.35s/it][A[A[A


 18%|█▊        | 4/22 [00:09<00:42,  2.38s/it][A[A[A


 23%|██▎       | 5/22 [00:11<00:40,  2.40s/it][A[A[A


 27%|██▋       | 6/22 [00:14<00:39,  2.46s/it][A[A[A


 32%|███▏      | 7/22 [00:16<00:36,  2.44s/it][A[A[A



  0%|          | 0/22 [00:00<?, ?it/s][A[A[A[A



  5%|▍         | 1/22 [00:02<00:48,  2.31s/it][A[A[A[A



  9%|▉         | 2/22 [00:04<00:46,  2.31s/it][A[A[A[A



 14%|█▎        | 3/22 [00:06<00:43,  2.31s/it][A[A[A[A



 18%|█▊        | 4/22 [00:09<00:41,  2.31s/it][A[A[A[A



 23%|██▎       | 5/22 [00:11<00:39,  2.31s/it][A[A[A[A



 27%|██▋       | 6/22 [00:13<00:37,  2.32s/it][A[A[A[A



 32%|███▏      | 7/22 [00:16<00:34,  2.32s/it][A[A[A[A




  0%|          | 0/22 [00:00<?, ?it/s][A[A[A[A

KeyboardInterrupt: 

In [33]:

dm = DataModule(
    file = 'data_extra', 
    batch_size=256,
    val_trans={
        'RandomCrop': {
            'height': 256, 
            'width': 256
        },
        'HorizontalFlip': {},
        'VerticalFlip': {},
        'Normalize': {}
    }
)
dm.setup()

Training samples:  21642
Validation samples:  5411


In [34]:
tta_preds = evaluate_tta(model, dm.val_dataloader(), tta=3, limit=0.33)
tta_preds









  0%|          | 0/22 [00:00<?, ?it/s][A[A[A[A[A[A[A[A







  5%|▍         | 1/22 [00:02<00:51,  2.44s/it][A[A[A[A[A[A[A[A







  9%|▉         | 2/22 [00:04<00:49,  2.47s/it][A[A[A[A[A[A[A[A







 14%|█▎        | 3/22 [00:07<00:46,  2.43s/it][A[A[A[A[A[A[A[A







 18%|█▊        | 4/22 [00:09<00:43,  2.41s/it][A[A[A[A[A[A[A[A







 23%|██▎       | 5/22 [00:12<00:40,  2.39s/it][A[A[A[A[A[A[A[A







 27%|██▋       | 6/22 [00:14<00:37,  2.37s/it][A[A[A[A[A[A[A[A







 32%|███▏      | 7/22 [00:16<00:35,  2.38s/it][A[A[A[A[A[A[A[A








  0%|          | 0/22 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A








  5%|▍         | 1/22 [00:02<00:50,  2.41s/it][A[A[A[A[A[A[A[A[A








  9%|▉         | 2/22 [00:04<00:48,  2.41s/it][A[A[A[A[A[A[A[A[A








 14%|█▎        | 3/22 [00:07<00:45,  2.40s/it][A[A[A[A[A[A[A[A[A








 18%|█▊        | 4/22 [00:09<00:42,  2.39s/it][A[A[A

0.8876953125

In [35]:
class FinalModelTTA(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.trans = torch.nn.Sequential(
            torchvision.transforms.RandomCrop(256),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomVerticalFlip()
        )
    
    def forward(self, x, tta : int = 1):
        x = x.float() / 255.
        x = x.permute(2, 0, 1)
        imgs = torch.stack([self.trans(x) for i in range(tta)])
        y_hat = self.model(imgs).mean(dim=0)
        return y_hat

In [36]:
from tqdm import tqdm 

for ix, model_name in tqdm(enumerate(models)):
    # load model
    model = Model.load_from_checkpoint(checkpoint_path=model_name)
    
    # export model tta
    final_model = FinalModelTTA(model.cpu())
    script = torch.jit.script(final_model.cpu())
    torch.jit.save(script, f"efficientnet_fold_{ix+1}_tta.pt")












0it [00:00, ?it/s][A[A[A[A[A[A[A[A[A[A[A










1it [00:01,  1.55s/it][A[A[A[A[A[A[A[A[A[A[A










2it [00:03,  1.78s/it][A[A[A[A[A[A[A[A[A[A[A










3it [00:06,  2.12s/it][A[A[A[A[A[A[A[A[A[A[A










4it [00:07,  1.85s/it][A[A[A[A[A[A[A[A[A[A[A










5it [00:10,  2.05s/it][A[A[A[A[A[A[A[A[A[A[A
