In [113]:
import pytorch_lightning as pl
import torch 
from pytorch_lightning.metrics.functional.classification import accuracy
from tqdm import tqdm
from src import DataModule, TIMM
import torchvision
import pandas as pd 
from pathlib import Path
from skimage import io
import glob
import timm

torch.__version__, timm.__version__

('1.7.0', '0.3.2')

In [114]:
models = glob.glob('resnet18*-pseudo-*.ckpt')
models

['resnet18-256-pseudo-val_acc=0.85781.ckpt']

In [115]:
model = TIMM.load_from_checkpoint(checkpoint_path='resnet18-256-pseudo-val_acc=0.85781.ckpt')
model.hparams

"backbone":        resnet18
"batch_size":      256
"es_start_from":   0
"lr":              0.0003
"max_epochs":      50
"num_workers":     20
"optimizer":       Adam
"patience":        3
"precision":       16
"pseudolabelling": 1
"size":            256
"subset":          0.1
"train_trans":     {'RandomCrop': {'height': 256, 'width': 256}, 'HorizontalFlip': {}, 'VerticalFlip': {}, 'Normalize': {}}
"val_batches":     5
"val_trans":       {'CenterCrop': {'height': 256, 'width': 256}, 'Normalize': {}}

In [116]:
dm = DataModule(
    file = 'data_extra', 
    batch_size=64,
    val_trans={
        'CenterCrop': {
            'height': model.hparams.size, 
            'width': model.hparams.size
        },
        'Normalize': {}
    }
)
dm.setup()

Training samples:  21642
Validation samples:  5411


In [117]:
 def evaluate(model, dl):   
    model.eval()
    model.cuda()
    acc = []
    with torch.no_grad():
        t = tqdm(dl)
        for x, y in t:
            x, y = x.cuda(), y.cuda()
            y_hat = model(x)
            acc.append(accuracy(y_hat, y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")

In [118]:
evaluate(model, dm.val_dataloader())

acc 0.79912: 100%|██████████| 85/85 [00:38<00:00,  2.21it/s]


In [32]:
def evaluate_tta(model, ds, tta = 0, limit = 1):   
    model.eval()
    model.cuda()
    acc = []
    with torch.no_grad():
        t = tqdm(range(len(ds)))
        for ix in t:
            y = ds[ix][1].unsqueeze(0).cuda()
            imgs = torch.stack([ds[ix][0] for i in range(tta+1)]).cuda()
            y_hat = model(imgs).mean(axis=0).unsqueeze(0)
            acc.append(accuracy(y_hat, y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")
            if ix >= int(limit*len(ds)):
                break

In [33]:
evaluate_tta(model, dm.val_ds, limit=0.1)

acc 0.82288:  10%|▉         | 541/5411 [00:23<03:35, 22.55it/s]


0.8228782287822878

In [34]:
dm = DataModule(
    file = 'data_extra', 
    val_trans={
        'RandomCrop': {
            'height': model.hparams.size, 
            'width': model.hparams.size
        },
        'HorizontalFlip': {},
        'VerticalFlip': {},
        'Normalize': {}
    }
)
dm.setup()

evaluate_tta(model, dm.val_ds, tta = 10, limit=0.1)

acc 1.00000:   0%|          | 1/5411 [00:00<09:41,  9.30it/s]

Training samples:  21642
Validation samples:  5411


acc 0.84686:  10%|▉         | 541/5411 [01:04<09:36,  8.44it/s]


0.8468634686346863

In [119]:
class FinalModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.trans = torch.nn.Sequential(
            torchvision.transforms.CenterCrop(256),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        )
        self.trans_tta = torch.nn.Sequential(
            torchvision.transforms.RandomCrop(256),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomVerticalFlip(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        )
    
    def forward(self, x, tta : int = 0):
        x = x.float() / 255.
        x = x.permute(2, 0, 1)
        if tta == 0:
            imgs = self.trans(x).unsqueeze(0)
            y_hat = self.model(imgs)[0]
        else:
            imgs = torch.stack([self.trans_tta(x) for i in range(tta)])
            y_hat = self.model(imgs).mean(dim=0)
        return torch.argmax(y_hat)

In [120]:
def evaluate_script(model, ds, tta = 0, limit = 1.):   
    model.eval()
    #model.cuda()
    acc = []
    inv_normalize = torchvision.transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.255],
        std=[1/0.229, 1/0.224, 1/0.255]
    )
    with torch.no_grad():
        t = tqdm(range(len(ds)))
        for ix in t:
            x, y = ds[ix]
            #x, y = x.cuda(), y.cuda()
            # simulate test
            x = inv_normalize(x)
            x *= 255. 
            x = x.permute(1, 2, 0).long()
            #print(x.shape, x.dtype, x.max(), x.min())
            y_hat= model(x, tta)
            acc.append((y_hat == y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")
            
            if ix >= int(limit*len(ds)):
                break
    return np.mean(acc)

In [121]:
final_model = FinalModel(model.m.cpu())

dm = DataModule(
    file = 'data_extra', 
    val_trans=None
)
dm.setup()

evaluate_script(final_model, dm.val_ds, limit=0.1)

acc 1.00000:   0%|          | 3/5411 [00:00<03:08, 28.69it/s]

Training samples:  21642
Validation samples:  5411


acc 0.80996:  10%|▉         | 541/5411 [00:16<02:26, 33.25it/s]


0.8099630996309963

In [122]:
script = torch.jit.script(final_model.cpu())
torch.jit.save(script, "model.pt")

In [123]:
loaded = torch.jit.load('model.pt')

In [124]:
evaluate_script(loaded, dm.val_ds, limit=0.1)

acc 0.80996:  10%|▉         | 541/5411 [00:15<02:22, 34.10it/s]


0.8099630996309963

In [69]:
evaluate_script(loaded, dm.val_ds, tta=10, limit=0.1)

acc 0.84793:   4%|▍         | 217/5411 [01:00<24:13,  3.57it/s]


KeyboardInterrupt: 

In [125]:
path = Path('./data/test_images')
images = os.listdir(path)
images_paths = [str(path/img) for img in images]
len(images)

1

In [126]:
def predict(model, imgs, tta=0, bs=32):   
    model.eval()
    #model.cuda()
    preds = []
    with torch.no_grad():
        for img in imgs:
            img = torch.from_numpy(io.imread(img))#.cuda()
            y_hat = model(img, tta)#.cpu()
            preds.append(y_hat.item())
    return preds

In [127]:
loaded = torch.jit.load('model.pt')
preds = predict(loaded, images_paths)
preds

[2]

In [128]:
submission = pd.DataFrame({'image_id': images, 'label': preds })
submission

Unnamed: 0,image_id,label
0,2216849948.jpg,2


In [129]:
submission.to_csv('submission.csv', index=False)