In [1]:
import pytorch_lightning as pl
import torch 
from pytorch_lightning.metrics.functional.classification import accuracy
from tqdm import tqdm
from src import DataModule, TIMM
import torchvision
import pandas as pd 
from pathlib import Path
from skimage import io
import glob

In [2]:
glob.glob('seresnext*')

['seresnext50_32x4d-256-val_acc=0.82656.ckpt']

In [3]:
model = TIMM.load_from_checkpoint(checkpoint_path='seresnext50_32x4d-256-val_acc=0.82656.ckpt')
model.hparams

"backbone":    seresnext50_32x4d
"batch_size":  128
"extra_data":  1
"lr":          0.0003
"max_epochs":  50
"num_workers": 0
"optimizer":   Adam
"precision":   16
"pretrained":  True
"size":        256
"subset":      0.1
"train_trans": {'RandomCrop': {'height': 256, 'width': 256}, 'HorizontalFlip': {}, 'VerticalFlip': {}, 'Normalize': {}}
"val_batches": 5
"val_trans":   {'CenterCrop': {'height': 256, 'width': 256}, 'Normalize': {}}

In [4]:
 def evaluate(model, dl):   
    model.eval()
    model.cuda()
    acc = []
    with torch.no_grad():
        t = tqdm(dl)
        for x, y in t:
            x, y = x.cuda(), y.cuda()
            y_hat = model(x)
            acc.append(accuracy(y_hat, y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")

In [5]:
dm = DataModule(
    file = 'data_extra', 
    batch_size=256,
    val_trans={
        'CenterCrop': {
            'height': 256, 
            'width': 256
        },
        'Normalize': {}
    }
)
dm.setup()

Training samples:  21642
Validation samples:  5411


In [6]:
evaluate(model, dm.val_dataloader())

acc 0.82241: 100%|█████████████████████████████████████████████████| 22/22 [00:59<00:00,  2.71s/it]


In [7]:
 def evaluate_tta(model, dl, tta = 1, limit = 1):   
    model.eval()
    model.cuda()
    tta_preds = []
    for i in range(tta):
        preds = torch.tensor([]).cuda()
        labels = torch.tensor([]).cuda()
        with torch.no_grad():
            t = tqdm(dl)
            for b, (x, y) in enumerate(t):
                x, y = x.cuda(), y.cuda()
                labels = torch.cat([labels, y])
                y_hat = model(x)
                preds = torch.cat([preds, y_hat])
                if b >= int(limit*len(dl)): 
                    break
        tta_preds.append(preds)
    tta_preds = torch.stack(tta_preds).mean(axis=0)
    return accuracy(tta_preds, labels).item()

In [8]:
tta_preds = evaluate_tta(model, dm.val_dataloader(), tta=3, limit=0.33)
tta_preds

 32%|████████████████████                                           | 7/22 [00:15<00:34,  2.28s/it]
 32%|████████████████████                                           | 7/22 [00:16<00:35,  2.37s/it]
 32%|████████████████████                                           | 7/22 [00:16<00:35,  2.37s/it]


0.8154296875

In [9]:
dm = DataModule(
    file = 'data_extra', 
    batch_size=256,
    val_trans={
        'RandomCrop': {
            'height': 256, 
            'width': 256
        },
        'HorizontalFlip': {},
        'VerticalFlip': {},
        'Normalize': {}
    }
)
dm.setup()

tta_preds = evaluate_tta(model, dm.val_dataloader(), tta=3, limit=0.33)
tta_preds

  0%|                                                                       | 0/22 [00:00<?, ?it/s]

Training samples:  21642
Validation samples:  5411


 32%|████████████████████                                           | 7/22 [00:16<00:34,  2.29s/it]
 32%|████████████████████                                           | 7/22 [00:16<00:35,  2.35s/it]
 32%|████████████████████                                           | 7/22 [00:15<00:33,  2.23s/it]


0.81884765625

In [10]:
path = Path('./data/test_images')
images = os.listdir(path)
images_paths = [str(path/img) for img in images]
len(images)

1

In [11]:
import cv2

class Dataset(torch.utils.data.Dataset):
    def __init__(self, imgs, trans=None):
        self.imgs = imgs
        self.trans = trans

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, ix):
        img = cv2.imread(self.imgs[ix])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.trans:
            img = self.trans(image=img)['image']
        img = torch.tensor(img, dtype=torch.float).permute(2,0,1)
        return img

In [12]:
import albumentations as A
trans = A.Compose([
    A.RandomCrop(256,256),
    A.Normalize()
])
dataset = Dataset(images_paths, trans)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=False)

In [13]:
 def predict(model, dl, tta = 1):   
    model.eval()
    model.cuda()
    tta_preds = []
    for i in range(tta):
        preds = torch.tensor([]).cuda()
        with torch.no_grad():
            t = tqdm(dl)
            for b, x in enumerate(t):
                x = x.cuda()
                y_hat = model(x)
                preds = torch.cat([preds, y_hat])
        tta_preds.append(preds)
    tta_preds = torch.stack(tta_preds).mean(axis=0)
    return torch.argmax(tta_preds, axis=1).cpu().numpy()

In [14]:
preds = predict(model, dataloader, tta=10)
preds

100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 38.46it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.00it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.00it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.00it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 38.47it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.00it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.00it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.01it/s]
100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 40.00it/s]


array([4], dtype=int64)

In [15]:
submission = pd.DataFrame({'image_id': images, 'label': preds })
submission

Unnamed: 0,image_id,label
0,2216849948.jpg,4


In [16]:
submission.to_csv('submission.csv', index=False)