In [None]:
from torchvision.models import resnet18,resnet34
from torch import nn
from torch.utils.data import DataLoader
import torch
import torchvision
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F 
import numpy as np
from torch.utils.data import  Dataset,DataLoader
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.core.decorators import auto_move_data

In [None]:
# Install PyTorch Lightning into the runtime via a shell escape.
# NOTE(review): the import cell above already imports pytorch_lightning, so on a
# fresh kernel this install has to run before that cell; the version is unpinned.
!pip install pytorch_lightning

In [None]:
class KMNIST(Dataset):
  """In-memory ``Dataset`` backed by the four ``k49-*.npz`` archives.

  Both the training and the test split are read from ``root`` when the object
  is created; the ``train`` flag only selects which split ``__len__`` and
  ``__getitem__`` expose.

  Args:
    train: serve the training split when true, the test split otherwise.
    root: directory holding ``k49-train-imgs.npz``, ``k49-train-labels.npz``,
      ``k49-test-imgs.npz`` and ``k49-test-labels.npz``. An empty string
      means the current working directory.

  Attributes:
    X_train, X_test: image tensors reshaped to ``(N, 1, 28, 28)`` and divided by 255.
    y_train, y_test: label tensors converted to ``torch.LongTensor``.
  """

  @staticmethod
  def _load_array(root, filename):
    """Read array ``arr_0`` from ``filename`` under ``root`` and return it as a tensor."""
    # With an empty root, a plain "root + '/' + name" would point at the filesystem root.
    path = f"{root}/{filename}" if root else filename
    # The context manager closes the archive's file handle once the array is in memory.
    with np.load(path) as archive:
      return torch.from_numpy(archive['arr_0'])

  def __init__(self, train = True, root = ''):
    self.root = root
    self.train = train

    # Images get an explicit single-channel axis and are divided by 255.
    self.X_train = self._load_array(root, "k49-train-imgs.npz").reshape(-1,1,28,28) / 255
    self.y_train = self._load_array(root, "k49-train-labels.npz").type(torch.LongTensor)

    self.X_test = self._load_array(root, "k49-test-imgs.npz").reshape(-1,1,28,28) / 255
    self.y_test = self._load_array(root, "k49-test-labels.npz").type(torch.LongTensor)

  def __len__(self):
    """Number of samples in the split selected by ``self.train``."""
    return self.X_train.shape[0] if self.train else self.X_test.shape[0]

  def __getitem__(self,index):
    """Return the ``(image, label)`` pair at ``index`` from the selected split."""
    if self.train:
      return self.X_train[index],self.y_train[index]
    return self.X_test[index],self.y_test[index]
# Both dataset objects read from the same folder; `train` picks the split each one serves.
data_train1 = KMNIST(train=True, root='/content/drive/MyDrive/hiragana')
data_test2 = KMNIST(train=False, root='/content/drive/MyDrive/hiragana')

In [None]:
# Batches of 64 for both splits; only the training loader shuffles its samples.
train_loader = DataLoader(dataset=data_train1, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=data_test2, batch_size=64, shuffle=False)

In [None]:
class ResNetKMNIST(pl.LightningModule):
  """ResNet-18 classifier for single-channel images, wrapped as a LightningModule.

  Args:
    num_classes: width of the final classification layer. Defaults to 49,
      the value that was previously hard-coded.
    learning_rate: learning rate passed to Adam. Defaults to 1e-3, which is
      Adam's own default, so ``ResNetKMNIST()`` trains exactly as before.
  """

  def __init__(self, num_classes=49, learning_rate=1e-3):
    super().__init__()
    self.learning_rate = learning_rate
    self.model = resnet18(num_classes=num_classes)
    # Replace the first convolution so the network accepts 1 input channel
    # (same 7x7 kernel, stride 2 and padding 3, 64 output channels, no bias).
    self.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    self.loss = nn.CrossEntropyLoss()

  @auto_move_data
  def forward(self, x):
    """Return the raw class logits for a batch of images ``x``."""
    return self.model(x)

  def training_step(self, batch, batch_no):
    """Compute the cross-entropy loss for one ``(images, labels)`` batch."""
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    return loss

  def configure_optimizers(self):
    """Optimise every parameter with Adam at ``self.learning_rate``."""
    return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# Build the classifier with its default constructor arguments.
model = ResNetKMNIST()

In [None]:
# Bare expression: the notebook displays the module's repr (its layer structure).
model

In [None]:
trainer = pl.Trainer(
    # Use one GPU when CUDA is available; 0 makes Lightning train on the CPU
    # instead of failing on a machine without a GPU.
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=20,
    # Redraw the progress bar every 20 batches.
    progress_bar_refresh_rate=20,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Run training on train_loader; no validation loader is passed.
trainer.fit(model, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 11.2 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.782    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [None]:
# Persist only the weights (state dict), not the whole module object.
torch.save(model.state_dict(), 'resnet18_v3.pt')

In [None]:
def get_prediction(x, model: nn.Module):
  """Classify a batch and return the predicted classes with their probabilities.

  The model is switched to evaluation mode and the forward pass runs without
  gradient tracking. Unlike ``freeze()``, this does not flip ``requires_grad``
  on the model's parameters, so the same object can still be trained later,
  and any ``nn.Module`` (not only a LightningModule) is accepted.

  Args:
    x: input batch; ``model(x)`` must return logits with classes on dim 1.
    model: the network to evaluate. It is left in evaluation mode.

  Returns:
    ``(predicted_class, probabilities)``: the argmax over dim 1 and the
    softmax over dim 1 of the model output.
  """
  model.eval()  # prepares model for predicting
  with torch.no_grad():
    probabilities = torch.softmax(model(x), dim=1)
  predicted_class = torch.argmax(probabilities, dim=1)
  return predicted_class, probabilities

In [None]:
from tqdm.autonotebook import tqdm

In [None]:
# map_location='cpu' lets the file load even when the tensors were saved from a
# device that is not present here; load_state_dict then copies the values into
# the new model's own parameters.
m_state_dict = torch.load('resnet18_v3.pt', map_location='cpu')
inference_model = ResNetKMNIST()
inference_model.load_state_dict(m_state_dict)

<All keys matched successfully>

In [None]:
# Evaluation mode only needs to be set once, not on every batch.
inference_model.eval()

true_y, pred_y = [], []
# tqdm takes its total from len(test_loader), so no explicit iter()/total is needed.
for x, y in tqdm(test_loader):
  preds, probs = get_prediction(x, inference_model)
  # Collect plain Python ints rather than 0-d tensors for the metrics step.
  true_y.extend(y.tolist())
  pred_y.extend(preds.cpu().tolist())

HBox(children=(FloatProgress(value=0.0, max=603.0), HTML(value='')))




In [None]:
from sklearn.metrics import classification_report

In [None]:
# Per-class precision / recall / F1 on the test predictions, printed to 3 decimals.
print(classification_report(true_y, pred_y, digits=3))

              precision    recall  f1-score   support

           0      0.960     0.972     0.966      1000
           1      0.981     0.982     0.982      1000
           2      0.959     0.964     0.962      1000
           3      0.870     0.952     0.909       126
           4      0.960     0.967     0.964      1000
           5      0.921     0.939     0.930      1000
           6      0.959     0.942     0.951      1000
           7      0.920     0.938     0.929      1000
           8      0.903     0.977     0.939       767
           9      0.944     0.966     0.955      1000
          10      0.960     0.973     0.966      1000
          11      0.982     0.944     0.963      1000
          12      0.955     0.926     0.940      1000
          13      0.967     0.937     0.951       678
          14      0.952     0.940     0.946       629
          15      0.957     0.974     0.965      1000
          16      0.967     0.974     0.970       418
          17      0.960    

In [None]:
# NOTE(review): this references the bound method without calling it, so the cell
# shows the method's repr rather than the weights; use model.state_dict() if the
# contents were intended.
model.state_dict

In [None]:
# Save the trained model's state dict under a second file name.
# NOTE(review): appears to duplicate the earlier save to 'resnet18_v3.pt' — confirm both files are needed.
torch.save(model.state_dict(), 'mymodule.pt')