In [3]:
import pytorch_lightning as pl
from src import DataModule, Resnet, Efficientnet
import torch
from pytorch_lightning.metrics.functional.classification import accuracy
from tqdm import tqdm
import numpy as np

In [2]:
checkpoint = torch.load('resnet18-512-val_acc=0.85654.ckpt')
checkpoint['hyper_parameters']

{'lr': 0.0003,
 'optimizer': 'Adam',
 'batch_size': 256,
 'max_epochs': 50,
 'precision': 16,
 'subset': 0,
 'test_size': 0.2,
 'seed': 42,
 'size': 512,
 'backbone': 'resnet18'}

In [4]:
model = Efficientnet.load_from_checkpoint('efficientnet-b0-256-val_acc=0.81682.ckpt')
model.hparams

Loaded pretrained weights for efficientnet-b0


"backbone":   efficientnet-b0
"batch_size": 128
"lr":         0.0003
"max_epochs": 50
"optimizer":  Adam
"precision":  16
"seed":       42
"size":       256
"subset":     0
"test_size":  0.2

In [4]:
def evaluate(model, dl):   
    model.eval()
    model.cuda()
    acc = []
    with torch.no_grad():
        t = tqdm(dl)
        for x, y in t:
            x, y = x.cuda(), y.cuda()
            y_hat = model(x)
            acc.append(accuracy(y_hat, y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")

In [5]:
dm = DataModule()
dm.setup()
evaluate(model, dm.val_dataloader())

  0%|          | 0/67 [00:00<?, ?it/s]

Training samples:  17117
Validation samples:  4280


acc 0.85608: 100%|██████████| 67/67 [00:19<00:00,  3.46it/s]


In [5]:
class Preprocess(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        x = x.float() / 255.
        x = x.permute(0, 3, 1, 2)
        return x 
    
class Postprocess(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return torch.argmax(x, dim=1)

In [None]:
script = torch.jit.trace(torch.nn.Sequential(
    Preprocess(),
    model.en.cpu(),
    Postprocess()
), torch.randn((64, 600, 800, 3)))
torch.jit.save(script, "enb0.pt")

In [None]:
def evaluate2(model, dl):   
    model.eval()
    model.cuda()
    acc = []
    with torch.no_grad():
        t = tqdm(dl)
        for x, y in t:
            x, y = x.cuda(), y.cuda()
            # simulate test
            x *= 255. 
            x = x.permute(0, 2, 3, 1).long()
            y_hat = model(x)
            acc.append(accuracy(y_hat, y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")

In [9]:
loaded = torch.jit.load('resnet18.pt')
evaluate2(loaded, dm.val_dataloader())

acc 0.85608: 100%|██████████| 67/67 [00:18<00:00,  3.60it/s]
