In [1]:
from torchvision import datasets, transforms, models

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint

import torch
from torch.nn import functional
from torch.utils.data import DataLoader, random_split

In [2]:
class ClassifyByCatDM(LightningDataModule):
    """DataModule that splits one ImageFolder directory into train/val/test.

    The directory at ``setupdir`` is read with ``torchvision.datasets.ImageFolder``
    (one sub-folder per class). ``train_frac`` of the images go to training;
    the remainder is divided roughly in half between validation and test, with
    any rounding leftover assigned to the test split.

    Args:
        setupdir: Root directory of the ImageFolder dataset.
        train_frac: Fraction of images used for training, in (0, 1].
        seed: Seed used before the random split so it is reproducible.
        batch_size: Batch size for all three dataloaders.

    Raises:
        ValueError: If ``train_frac`` is not in (0, 1].
    """

    def __init__(self, setupdir, train_frac=0.9, seed=0, batch_size=64):
        super().__init__()

        # Values outside (0, 1] would yield negative split lengths.
        if not 0.0 < train_frac <= 1.0:
            raise ValueError(f"train_frac must be in (0, 1], got {train_frac!r}")

        self.setupdir = setupdir
        self.train_frac = train_frac
        self.seed = seed
        self.batch_size = batch_size

        # Resize/crop to 224x224 and normalise with the ImageNet channel
        # statistics used by torchvision's pretrained models.
        self.transform = transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        """Load the image folder and create the seeded train/val/test split.

        ``stage`` is accepted so Lightning can invoke this hook as
        ``setup(stage=...)``; it is ignored and every stage gets the same split.
        Sets ``num_classes``, ``train``, ``val`` and ``test``.
        """
        # Seeds torch's global RNG: this makes the split reproducible, and it
        # also affects anything that draws from the global RNG afterwards.
        torch.manual_seed(self.seed)

        # random_split subsets all wrap this single dataset object, so the
        # transform is attached once here rather than per subset.
        dataset = datasets.ImageFolder(self.setupdir, transform=self.transform)
        self.num_classes = len(dataset.classes)

        set_len = len(dataset)
        train_len = int(set_len * self.train_frac)
        val_len = int(set_len * (1 - self.train_frac) / 2)
        test_len = set_len - train_len - val_len  # absorbs rounding leftovers

        self.train, self.val, self.test = random_split(
            dataset, [train_len, val_len, test_len])

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation subset."""
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Unshuffled loader over the test subset."""
        return DataLoader(self.test, batch_size=self.batch_size)

In [3]:
class ClassifyModel(LightningModule):
    """Image classifier: pretrained ResNet-34 followed by a linear head.

    The full torchvision ResNet-34 is used as the feature extractor (so its
    output is presumably the 1000 ImageNet logits — TODO confirm); that output
    is flattened and fed to a ``Linear`` layer with ``num_classes`` outputs.
    ``forward`` returns log-probabilities, so the steps use ``nll_loss``.
    All parameters, backbone included, are handed to the optimizer.

    Args:
        input_shape: ``(channels, height, width)`` of one input image; used to
            size the linear head with a dummy forward pass.
        num_classes: Number of output classes.
        learning_rate: Adam learning rate.
        batch_size: Stored (and saved as a hyperparameter); not used by the
            model itself.

    Attributes:
        predictions: Predicted class indices (0-d CPU tensors) collected by
            ``test_step`` during the most recent test run.
    """

    def __init__(self, input_shape, num_classes,
                 learning_rate=1e-4, batch_size=64):
        super().__init__()

        self.batch_size = batch_size

        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes

        self.feature_extractor = models.resnet34(pretrained=True)
        # eval() so the dummy sizing pass below does not update BatchNorm
        # running statistics with random input. NOTE(review): this does not
        # freeze the backbone — Lightning presumably switches the model back
        # to train mode for fitting, and requires_grad is left untouched.
        self.feature_extractor.eval()

        n_sizes = self._get_conv_output(input_shape)
        self.classifier = torch.nn.Linear(n_sizes, num_classes)

        self.predictions = []

    def _get_conv_output(self, shape):
        """Return the flattened feature size produced for one input of ``shape``."""
        batch_size = 1
        # Throwaway probe pass: no_grad avoids building an autograd graph.
        with torch.no_grad():
            inp = torch.rand(batch_size, *shape)
            features = self._forward_features(inp)
        return features.reshape(batch_size, -1).size(1)

    def _forward_features(self, x):
        """Run the backbone on a batch of images."""
        return self.feature_extractor(x)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, num_classes)``."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        return functional.log_softmax(self.classifier(x), dim=1)

    def _shared_step(self, batch):
        """Compute NLL loss, accuracy and predicted classes for one ``(x, y)`` batch."""
        x, y = batch
        log_probs = self(x)
        loss = functional.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)
        return loss, acc, preds

    def training_step(self, batch, batch_idx):
        loss, acc, _ = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc, _ = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def on_test_epoch_start(self):
        """Clear collected predictions so repeated test runs do not accumulate."""
        self.predictions = []

    def test_step(self, batch, batch_idx):
        loss, acc, preds = self._shared_step(batch)
        # Iterating a 1-D tensor yields 0-d tensors (same as preds[i]); moving
        # them to CPU keeps the list from pinning accelerator memory.
        self.predictions.extend(preds.detach().cpu())
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all model parameters at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [4]:
# A tiny batch size keeps this demo quick on the small example dataset.
batch_size = 2
dm = ClassifyByCatDM(
    setupdir='small_dataset_sorted_by_cat',
    train_frac=0.5,
    seed=0,
    batch_size=batch_size,
)
# Split eagerly so dm.num_classes is available before the model is built.
dm.setup()

In [5]:
# setup() discovered the class sub-folders; size the classifier head to match.
num_classes = dm.num_classes
# (channels, height, width) must agree with the DataModule's 224x224 crop.
model = ClassifyModel(
    (3, 224, 224),
    num_classes,
    learning_rate=2e-4,
    batch_size=batch_size,
)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [6]:
# Keep the checkpoint with the lowest validation loss so that testing with
# ckpt_path='best' evaluates that epoch. Without a monitored metric the
# default checkpoint callback presumably keeps only the latest epoch.
checkpoint_cb = ModelCheckpoint(monitor='val_loss', mode='min')

trainer = Trainer(max_epochs=4,
                  progress_bar_refresh_rate=1,
                  callbacks=[checkpoint_cb])

trainer.fit(model, dm)

# Evaluate the held-out test split using the best (lowest val_loss) weights.
trainer.test(ckpt_path='best')

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(

  | Name              | Type   | Params
---------------------------------------------
0 | feature_extractor | ResNet | 21.8 M
1 | classifier        | Linear | 5.0 K 
---------------------------------------------
21.8 M    Trainable params
0         Non-trainable params
21.8 M    Total params
87.211    Total estimated model params size (MB)


Validation sanity check:   0%|                            | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(
  stream(template_mgs % msg_args)


                                                                                

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  67%|██████████▋     | 6/9 [00:01<00:00,  4.48it/s, loss=2.55, v_num=5]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                         | 0/3 [00:00<?, ?it/s][A
Epoch 0: 100%|█| 9/9 [00:01<00:00,  5.49it/s, loss=2.55, v_num=5, val_loss=2.490[A
Epoch 1:  67%|▋| 6/9 [00:01<00:00,  4.63it/s, loss=2.24, v_num=5, val_loss=2.490[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                         | 0/3 [00:00<?, ?it/s][A
Epoch 1: 100%|█| 9/9 [00:01<00:00,  5.66it/s, loss=2.24, v_num=5, val_loss=2.280[A
Epoch 2:  67%|▋| 6/9 [00:01<00:00,  4.51it/s, loss=1.98, v_num=5, val_loss=2.280[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                         | 0/3 [00:00<?, ?it/s][A
Epoch 2: 100%|█| 9/9 [00:01<00:00,  5.46it/s, loss=1.98, v_num=5, val_loss=1.680[A
Epoch 3:  67%|▋| 6/9 [00:01<00:00,  4.43it/s, loss=1.5, v_num=5, val_loss=1.680,[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          

  rank_zero_deprecation(
  rank_zero_warn(



Testing:  75%|███████████████████████████         | 3/4 [00:00<00:00, 11.59it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.2857142984867096, 'test_loss': 2.3450584411621094}
--------------------------------------------------------------------------------
Testing: 100%|████████████████████████████████████| 4/4 [00:00<00:00, 12.65it/s]


[{'test_loss': 2.3450584411621094, 'test_acc': 0.2857142984867096}]