In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
from torchvision.datasets import CIFAR10
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl


class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, data_path, batch_size, lr):
        super().__init__()
        
        self.data_path = data_path
        self.batch_size = batch_size
        self.lr = lr
        

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()
        
        # init a pretrained resnet
        backbone = models.resnet18(pretrained=True)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify food101
        self.classifier = nn.Linear(num_filters, 101)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        
        return x
        
    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("train_loss", loss)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        input, target = batch
        output = self(input)
        loss = self.loss_fn(output, target)
        
        self.log("val_loss", loss)
        
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)
        
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=8,
                                           shuffle=True)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                                           batch_size=self.batch_size,
                                           num_workers=8,
                                           shuffle=False)


In [10]:
import pytorch_lightning as pl
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
from torchvision.datasets import Food101
import torch
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

class Food101(pl.LightningDataModule):
    def prepare_data(self):
        transform=transforms.Compose([
            transforms.RandomResizedCrop(128),
            transforms.RandomCrop(128),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],
                                 [0.229, 0.224, 0.225])])
      
        self.train_data = datasets.Food101('C/Users/jasti/Downloads/data', split="train", download=True, transform=transform)
        self.val_data = datasets.Food101('C/Users/jasti/Downloads/data', split="test", download=True, transform=transform)
        self.test_data = datasets.Food101('C/Users/jasti/Downloads/data', split="test", download=True, transform=transform)
        print('[INFO] data loaded successfully')

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size= 32, shuffle=True, num_workers=8)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size= 32, shuffle=True, num_workers=8)
    
    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size= 32, shuffle=True, num_workers=8)


In [11]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min"
)

Food101 = Food101()
model = ImagenetTransferLearning("data/cifar10/", 32, 1e-3)

trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=5, num_sanity_val_steps=0)
trainer.fit(model,Food101)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Downloading https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz to C/Users/jasti/Downloads/data/food-101.tar.gz


  0%|          | 0/4996278331 [00:00<?, ?it/s]

Extracting C/Users/jasti/Downloads/data/food-101.tar.gz to C/Users/jasti/Downloads/data


INFO:pytorch_lightning.callbacks.model_summary:
  | Name              | Type             | Params
-------------------------------------------------------
0 | loss_fn           | CrossEntropyLoss | 0     
1 | feature_extractor | Sequential       | 11.2 M
2 | classifier        | Linear           | 51.8 K
-------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.913    Total estimated model params size (MB)


[INFO] data loaded successfully


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [13]:
trainer.validate(model, Food101)

[INFO] data loaded successfully


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val_loss            2.2584116458892822
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 2.2584116458892822}]