In [70]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torchvision
from pytorch_lightning.metrics.functional.classification import accuracy
from skimage import io
from tqdm import tqdm

from src import DataModule, Model

In [27]:
# Build the data module for the images under ./data and create its
# train/validation splits (setup() prints the split sizes).
dm = DataModule(path=Path('data'), batch_size=64)
dm.setup()

Training samples:  17117
Validation samples:  4280


In [28]:
# Inspect the hyper-parameters stored in the trained checkpoint.
# map_location='cpu' lets the file be read on a machine without the GPU it was
# saved from; only the hparams dict is needed here.
# torch.load unpickles the file -- only use it on trusted checkpoints.
checkpoint = torch.load('./resnet50-val_acc=0.84112.ckpt', map_location='cpu')
hparams = checkpoint['hyper_parameters']
hparams

{'lr': 3e-05, 'batch_size': 128, 'max_epochs': 50, 'precision': 16}

In [29]:
# Rebuild the LightningModule (weights + saved hparams) from the same checkpoint
# and show the hparams it was restored with.
model = Model.load_from_checkpoint(
    checkpoint_path='./resnet50-val_acc=0.84112.ckpt',
)
model.hparams

"batch_size": 128
"lr":         3e-05
"max_epochs": 50
"precision":  16

In [30]:
 def evaluate(model, dl):   
    model.eval()
    model.cuda()
    acc = []
    with torch.no_grad():
        t = tqdm(dl)
        for x, y in t:
            x, y = x.cuda(), y.cuda()
            y_hat = model(x)
            acc.append(accuracy(y_hat, y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")
            
evaluate(model, dm.val_dataloader())

acc 0.84995: 100%|██████████| 67/67 [00:42<00:00,  1.59it/s]


In [85]:
class Preprocess(torch.nn.Module):
    """Map a channels-last batch of 0-255 pixel values to float, channels-first, in [0, 1]."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Rescale 0..255 -> 0..1, then move the channel axis from last to second.
        scaled = x.float().div(255.)
        return scaled.permute(0, 3, 1, 2)


class Postprocess(torch.nn.Module):
    """Reduce per-class scores to the index of the highest score along dim 1."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.argmax(dim=1)

In [86]:
# Export preprocessing + trained backbone + argmax as a single TorchScript module
# that maps a channels-last batch of 0-255 pixels to predicted class ids.
# `.eval()` is applied to the whole pipeline so the saved module carries
# inference-mode flags no matter which cells ran before this one.
# NOTE: `.cpu()` moves `model.resnet` in place; later cells call `.cuda()` again.
pipeline = torch.nn.Sequential(
    Preprocess(),
    model.resnet.cpu(),
    Postprocess(),
).eval()
script = torch.jit.script(pipeline)
torch.jit.save(script, "model.pt")

In [87]:
 def evaluate2(model, dl):   
    model.eval()
    model.cuda()
    acc = []
    with torch.no_grad():
        t = tqdm(dl)
        for x, y in t:
            x, y = x.cuda(), y.cuda()
            # simulate test
            x *= 255. 
            x = x.permute(0, 2, 3, 1).long()
            #print(x.shape, x.dtype, x.max(), x.min())
            y_hat = model(x)
            acc.append(accuracy(y_hat, y).item())
            t.set_description(f"acc {np.mean(acc):.5f}")

In [88]:
# Reload the exported TorchScript artifact from disk and score it on the
# validation split using test-style (integer, channels-last) inputs.
loaded = torch.jit.load(
    'model.pt',
)
evaluate2(loaded, dm.val_dataloader())

acc 0.84995: 100%|██████████| 67/67 [00:43<00:00,  1.55it/s]


In [91]:
# Collect the test images in a deterministic (sorted) order. os.listdir order is
# arbitrary, which would reorder the submission rows between runs/machines.
path = Path('./data/test_images')
images = sorted(entry.name for entry in path.iterdir())
images_paths = [str(path / img) for img in images]
len(images)

1

In [96]:
def predict(model, images, bs=32, device="cuda"):
    """Predict class ids for a list of image files with an exported pipeline.

    Args:
        model: module taking a batch of images as read by ``skimage.io.imread``
            (stacked along a new leading axis); expected to return one
            predicted class id per image.
        images: sequence of image file paths.
        bs: number of images read and scored per forward pass.
        device: device to run inference on (defaults to ``"cuda"`` as before).

    Returns:
        1-D int64 numpy array of predictions, in the same order as ``images``
        (empty if ``images`` is empty).
    """
    model.eval()
    model.to(device)
    batch_preds = []
    with torch.no_grad():
        # Stepping by `bs` yields exactly ceil(len(images) / bs) non-empty
        # batches; `len(images) // bs + 1` produced an empty trailing batch
        # (which crashes in the model) whenever len(images) % bs == 0.
        for start in tqdm(range(0, len(images), bs)):
            batch_paths = images[start:start + bs]
            imgs = torch.from_numpy(np.stack([io.imread(p) for p in batch_paths]))
            batch_preds.append(model(imgs.to(device)).long().cpu())
    if not batch_preds:
        return np.empty(0, dtype=np.int64)
    return torch.cat(batch_preds).numpy()

In [97]:
# Run the exported pipeline over every test image and show the predicted ids.
loaded = torch.jit.load('model.pt')
preds = predict(model=loaded, images=images_paths)
preds

100%|██████████| 1/1 [00:00<00:00,  7.60it/s]

1





array([4])

In [98]:
# One row per test image: its file name and the predicted class id.
submission = pd.DataFrame(data={'image_id': images, 'label': preds})
submission

Unnamed: 0,image_id,label
0,2216849948.jpg,4


In [99]:
# Write the submission file without the DataFrame's index column.
submission.to_csv(path_or_buf='submission.csv', index=False)