In [43]:
import os
import torch
import json
from PIL import Image
from torchvision import transforms

In [34]:
import torch.nn as nn
import pytorch_lightning as pl
from torch.nn import functional as F
from torchmetrics.functional.classification import accuracy
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim import Adam
from torchvision import models
from typing import Optional
from torch.nn import Module

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def _make_trainable(module: Module) -> None:
    '''Unfreezes a given module.
    Args:
        module: The module to unfreeze
    '''
    for param in module.parameters():
        param.requires_grad = True
    module.train()


def _recursive_freeze(module: Module,
                      train_bn: bool = True) -> None:
    '''Freezes the layers of a given module.
    Args:
        module: The module to freeze
        train_bn: If True, leave the BatchNorm layers in training mode
    '''
    children = list(module.children())
    if not children:
        if not (isinstance(module, BN_TYPES) and train_bn):
            for param in module.parameters():
                param.requires_grad = False
            module.eval()
        else:
            # Make the BN layers trainable
            _make_trainable(module)
    else:
        for child in children:
            _recursive_freeze(module=child, train_bn=train_bn)


def freeze(module: Module,
           n: Optional[int] = None,
           train_bn: bool = True) -> None:
    '''Freezes the layers up to index n (if n is not None).
    Args:
        module: The module to freeze (at least partially)
        n: Max depth at which we stop freezing the layers. If None, all
            the layers of the given module will be frozen.
        train_bn: If True, leave the BatchNorm layers in training mode
    '''
    children = list(module.children())
    n_max = len(children) if n is None else int(n)

    for child in children[:n_max]:
        _recursive_freeze(module=child, train_bn=train_bn)

    for child in children[n_max:]:
        _make_trainable(module=child)

class ResNet152(pl.LightningModule):

    def __init__(self, 
                train_bn: bool = True,
                lr: float = 1e-3,
                num_workers: int = 4,
                hidden_1: int = 1024,
                hidden_2: int = 512,
                epoch_freeze: int = 8,
                total_steps: int = 15,
                pct_start: float = 0.2,
                anneal_strategy: str = 'cos',
                **kwargs):
        super().__init__()
        self.train_bn = train_bn
        self.lr = lr
        self.num_workers = num_workers
        self.hidden_1 = hidden_1
        self.hidden_2 = hidden_2
        self.epoch_freeze = epoch_freeze
        self.total_steps = total_steps
        self.pct_start = pct_start
        self.anneal_strategy = anneal_strategy
        self.save_hyperparameters()
        self.__build_model()
        
    def __build_model(self):
        num_target_classes = 196
        backbone = models.resnet152(pretrained=True)
    
        _layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*_layers)

        _fc_layers = [nn.Linear(2048, self.hidden_1),
                     nn.Linear(self.hidden_1, self.hidden_2),
                     nn.Linear(self.hidden_2, num_target_classes)]
        self.fc = nn.Sequential(*_fc_layers)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.squeeze(-1).squeeze(-1)
        x = self.fc(x)
        return x
    
    def train(self, mode=True):
        super().train(mode=mode)
        epoch = self.current_epoch
        if epoch < self.epoch_freeze and mode:
            freeze(module=self.feature_extractor,
                   train_bn=self.train_bn) 
            
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_logits = self.forward(x)
        train_loss = F.cross_entropy(y_logits, y)
        acc = accuracy(y_logits, y)
        self.log('acc', acc, prog_bar=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_logits = self.forward(x)
        val_loss = F.cross_entropy(y_logits, y)
        acc = accuracy(y_logits, y)
        self.log('val_loss', val_loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        if self.current_epoch < self.epoch_freeze:
            optimizer = Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)
            return optimizer
        else:
            optimizer = Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)     
            scheduler = OneCycleLR(optimizer,
                            max_lr=self.lr,
                            total_steps=self.total_steps,
                            pct_start=self.pct_start, anneal_strategy=self.anneal_strategy)
        return [optimizer], [scheduler]

In [35]:
model = ResNet152()


In [36]:
model_dir = './model'
model.load_state_dict(torch.load(model_dir + '/model.pth'))

<All keys matched successfully>

In [37]:
model.eval()

In [26]:
image = Image.open('sedan.jpg').convert('RGB')

In [27]:
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
        transforms.Resize((400, 400)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std, inplace=True)
    ])

image_preprocessed = preprocess(image)
batch_image_tensor = torch.unsqueeze(image_preprocessed, 0)

In [38]:
output = model(batch_image_tensor)

In [39]:
output

tensor([[-0.1393, -0.1973,  0.3063, -0.1849,  0.2180, -0.0921,  0.1180,  2.7630,
          0.0236,  0.0516, -0.2295, -0.3622,  0.0693,  0.0716, -0.5010,  0.1693,
          0.2861, -0.3926,  0.0649,  0.2917, -0.9888,  0.2420, -0.6086,  0.3551,
          0.3682,  0.1606,  0.1521, -0.0119, -0.3442,  0.2192, -0.3687, -0.2173,
         -0.1415, -0.3838, -0.7095, -0.1164, -0.5318, -0.1630, -0.0683, -0.1332,
          0.4097, -0.7286,  0.3359, -0.3529,  0.3977, -0.3562, -0.1357, -0.1311,
         -0.0160,  0.2227, -0.6012, -0.3036, -0.1461,  0.2659,  0.0095, -0.4154,
          0.0045, -0.0515, -0.1052,  0.0067, -0.4168,  0.2203,  0.0409,  0.0870,
         -0.1089,  0.4164,  0.4776, -0.3142, -0.3883, -0.0778, -0.3655,  0.0586,
          0.1143, -0.1125,  0.3734, -0.0491,  0.2777,  0.1619, -0.2813, -0.0590,
         -0.1060,  0.0215,  0.0240,  0.0913,  0.1136,  0.5203, -0.4383, -0.6874,
         -0.3167,  0.1992,  0.0772,  0.1687, -0.1164,  0.1934,  0.1289, -0.5103,
         -0.2795, -0.4316,  

In [41]:
_, index = torch.max(output, 1)

In [45]:
classes = {}
with open(model_dir + '/classes.json') as f:
    classes = json.load(f)

In [None]:
def getKeyByValue(classes, index):
    listOfKeys = list()
    listOfItems = dictOfElements.items()
    for item  in listOfItems:
        if item[1] == valueToFind:
            listOfKeys.append(item[0])
    return  listOfKeys

In [52]:
int(index[0])

110

In [54]:
for item in classes:
    if classes[item] == int(index[0]):
        print(item)

Ford F-450 Super Duty Crew Cab 2012
