In [2]:
import pytorch_lightning as pl
import torch 
from pytorch_lightning.metrics.functional.classification import accuracy
from tqdm import tqdm
from src import DataModule, Model
import torchvision
import pandas as pd 
from pathlib import Path
from skimage import io
import glob

In [3]:
glob.glob('seresnext*')

['seresnext50_32x4d-256-val_acc=0.59531.ckpt',
 'seresnext50_32x4d-256-val_acc=0.78125.ckpt',
 'seresnext50_32x4d-256-val_acc=0.80156.ckpt',
 'seresnext50_32x4d-256-val_acc=0.80625.ckpt',
 'seresnext50_32x4d-256-val_acc=0.83906.ckpt']

In [5]:
model = Model.load_from_checkpoint(checkpoint_path='seresnext50_32x4d-256-val_acc=0.83906.ckpt')
model.hparams

"backbone":      seresnext50_32x4d
"batch_size":    128
"es_start_from": 0
"extra_data":    1
"lr":            1e-05
"max_epochs":    10
"num_workers":   20
"optimizer":     Adam
"patience":      3
"precision":     16
"pretrained":    True
"scheduler":     {'OneCycleLR': {'max_lr': 0.001, 'total_steps': 10, 'pct_start': 0.2, 'verbose': True}}
"size":          256
"subset":        0.1
"train_trans":   {'RandomCrop': {'height': 256, 'width': 256}, 'HorizontalFlip': {}, 'VerticalFlip': {}, 'Normalize': {}}
"unfreeze":      0
"val_batches":   5
"val_trans":     {'CenterCrop': {'height': 256, 'width': 256}, 'Normalize': {}}

In [6]:
 def evaluate(model, dl):   
    model.eval()
    model.cuda()
    acc = []
    with torch.no_grad():
        t = tqdm(dl)
        for x, y in t:
            x, y = x.cuda(), y.cuda()
            y_hat = model(x)
            acc.append(accuracy(y_hat, y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")

In [7]:
dm = DataModule(
    file = 'data_extra', 
    batch_size=256,
    val_trans={
        'CenterCrop': {
            'height': 256, 
            'width': 256
        },
        'Normalize': {}
    }
)
dm.setup()

Training samples:  21642
Validation samples:  5411


In [8]:
evaluate(model, dm.val_dataloader())

acc 0.82797: 100%|█████████████████████████████████████████| 22/22 [01:32<00:00,  4.21s/it]


In [9]:
 def evaluate_tta(model, dl, tta = 1, limit = 1):   
    model.eval()
    model.cuda()
    tta_preds = []
    for i in range(tta):
        preds = torch.tensor([]).cuda()
        labels = torch.tensor([]).cuda()
        with torch.no_grad():
            t = tqdm(dl)
            for b, (x, y) in enumerate(t):
                x, y = x.cuda(), y.cuda()
                labels = torch.cat([labels, y])
                y_hat = model(x)
                preds = torch.cat([preds, y_hat])
                if b >= int(limit*len(dl)): 
                    break
        tta_preds.append(preds)
    tta_preds = torch.stack(tta_preds).mean(axis=0)
    return accuracy(tta_preds, labels).item()

In [10]:
tta_preds = evaluate_tta(model, dm.val_dataloader(), tta=3, limit=0.33)
tta_preds

 32%|█████████████████▌                                     | 7/22 [00:20<00:43,  2.92s/it]
 32%|█████████████████▌                                     | 7/22 [00:20<00:43,  2.90s/it]
 32%|█████████████████▌                                     | 7/22 [00:19<00:42,  2.84s/it]


0.82080078125

In [11]:
dm = DataModule(
    file = 'data_extra', 
    batch_size=256,
    val_trans={
        'RandomCrop': {
            'height': 256, 
            'width': 256
        },
        'HorizontalFlip': {},
        'VerticalFlip': {},
        'Normalize': {}
    }
)
dm.setup()

tta_preds = evaluate_tta(model, dm.val_dataloader(), tta=3, limit=0.33)
tta_preds

  0%|                                                               | 0/22 [00:00<?, ?it/s]

Training samples:  21642
Validation samples:  5411


 32%|█████████████████▌                                     | 7/22 [00:20<00:44,  2.95s/it]
 32%|█████████████████▌                                     | 7/22 [00:19<00:42,  2.85s/it]
 32%|█████████████████▌                                     | 7/22 [00:19<00:41,  2.80s/it]


0.82861328125

In [12]:
path = Path('./data/test_images')
images = os.listdir(path)
images_paths = [str(path/img) for img in images]
len(images)

1

In [13]:
import cv2

class Dataset(torch.utils.data.Dataset):
    def __init__(self, imgs, trans=None):
        self.imgs = imgs
        self.trans = trans

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, ix):
        img = cv2.imread(self.imgs[ix])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.trans:
            img = self.trans(image=img)['image']
        img = torch.tensor(img, dtype=torch.float).permute(2,0,1)
        return img

In [14]:
import albumentations as A
trans = A.Compose([
    A.RandomCrop(256,256),
    A.Normalize()
])
dataset = Dataset(images_paths, trans)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=False)

In [15]:
 def predict(model, dl, tta = 1):   
    model.eval()
    model.cuda()
    tta_preds = []
    for i in range(tta):
        preds = torch.tensor([]).cuda()
        with torch.no_grad():
            t = tqdm(dl)
            for b, x in enumerate(t):
                x = x.cuda()
                y_hat = model(x)
                preds = torch.cat([preds, y_hat])
        tta_preds.append(preds)
    tta_preds = torch.stack(tta_preds).mean(axis=0)
    return torch.argmax(tta_preds, axis=1).cpu().numpy()

In [16]:
preds = predict(model, dataloader, tta=10)
preds

100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.86it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 31.25it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 31.25it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 28.57it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 23.26it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 31.25it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 27.78it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 27.78it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 28.57it/s]
100%|████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 30.30it/s]


array([4], dtype=int64)

In [17]:
submission = pd.DataFrame({'image_id': images, 'label': preds })
submission

Unnamed: 0,image_id,label
0,2216849948.jpg,4


In [18]:
submission.to_csv('submission.csv', index=False)