In [1]:
# !pip install pytorch-lightning

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning.core.lightning import LightningModule

import numpy as np
from PIL import Image

  assert isinstance(args, collections.Mapping), '{} args must be a dict with argument names as keys.'.format(name)
  from collections import OrderedDict, Sequence, defaultdict


In [3]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

jpeg_blur_counter = 0
def get_jpeg_blur(quality):
    def jpeg_blur(img):
        jpeg_blur_counter=0
        img.save(f"{jpeg_blur_counter}.jpg",quality=quality)
        img2 = Image.open(f"{jpeg_blur_counter}.jpg")
        jpeg_blur_counter += 1
        return img2
    return jpeg_blur

In [27]:
import torch.optim as optim

class LitClassifier(LightningModule):
    def __init__(self, quality):
        super().__init__()
        self.quality = quality
        self.net = Net()
        if quality is not None:
            self.transform = transforms.Compose(
                [transforms.Lambda(get_jpeg_blur(quality)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])
        else:
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def train_dataloader(self):
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=self.transform)
        return torch.utils.data.DataLoader(trainset, batch_size=4,
                                                  shuffle=True, pin_memory=True)
        
    def validation_step(self, batch, batch_idx, _):
        print(batch.shape)
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        print('val end')
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def val_dataloader(self):
        testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                              download=True, transform=self.transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                                   shuffle=False, pin_memory=True)
        print(type(testloader))
        return testloader


In [28]:
dl = model.val_dataloader()

Files already downloaded and verified
<class 'torch.utils.data.dataloader.DataLoader'>


In [29]:
next(iter(dl))

[tensor([[[[ 0.2392,  0.2471,  0.2941,  ...,  0.0745, -0.0118, -0.0902],
           [ 0.1922,  0.1843,  0.2471,  ...,  0.0667, -0.0196, -0.0667],
           [ 0.1843,  0.1843,  0.2392,  ...,  0.0902,  0.0196, -0.0588],
           ...,
           [-0.4667, -0.6706, -0.7569,  ..., -0.7020, -0.8980, -0.6863],
           [-0.5216, -0.6157, -0.7255,  ..., -0.7961, -0.7725, -0.8431],
           [-0.5765, -0.5608, -0.6471,  ..., -0.8118, -0.7333, -0.8353]],
 
          [[-0.1216, -0.1294, -0.0902,  ..., -0.2549, -0.2863, -0.3333],
           [-0.1216, -0.1373, -0.1059,  ..., -0.2549, -0.2863, -0.3098],
           [-0.1373, -0.1451, -0.1294,  ..., -0.2314, -0.2549, -0.3020],
           ...,
           [-0.0275, -0.2157, -0.3098,  ..., -0.2392, -0.4980, -0.3333],
           [-0.0902, -0.2000, -0.3333,  ..., -0.3569, -0.3569, -0.4980],
           [-0.1608, -0.1765, -0.3020,  ..., -0.3961, -0.3412, -0.4745]],
 
          [[-0.6157, -0.6314, -0.6000,  ..., -0.7176, -0.7176, -0.7412],
           [-

In [30]:
from pytorch_lightning import Trainer
from pytorch_lightning.logging import CometLogger
import os

comet_logger = None# CometLogger(api_key=os.environ.get('COMET_API_KEY'), workspace='dl4cv', project_name='HF4', experiment_name='original')

model = LitClassifier(None)

trainer = Trainer(gpus=1, min_epochs=10, logger=comet_logger)
r = trainer.fit(model)



# classes = ('plane', 'car', 'bird', 'cat',
#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

INFO:pytorch_lightning.logging.comet:CometLogger will be initialized in online mode
COMET INFO: ----------------------------
COMET INFO: Comet.ml Experiment Summary:
COMET INFO:   Data:
COMET INFO:     url: https://www.comet.ml/dl4cv/hf4/e132b0efe1024ce7b28f6b0f8353a334
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     sys.cpu.percent.01 [2]       : (15.6, 15.9)
COMET INFO:     sys.cpu.percent.02 [2]       : (2.0, 3.3)
COMET INFO:     sys.cpu.percent.03 [2]       : (1.5, 4.7)
COMET INFO:     sys.cpu.percent.04 [2]       : (2.0, 4.5)
COMET INFO:     sys.cpu.percent.avg [2]      : (5.975, 6.3999999999999995)
COMET INFO:     sys.gpu.0.free_memory [3]    : (3757703168.0, 3757703168.0)
COMET INFO:     sys.gpu.0.gpu_utilization [3]: (0.0, 0.0)
COMET INFO:     sys.gpu.0.total_memory       : (4238999552.0, 4238999552.0)
COMET INFO:     sys.gpu.0.used_memory [3]    : (481296384.0, 481296384.0)
COMET INFO:     sys.load.avg [2]             : (0.21, 0.26)
COMET INFO:     sys.ram.total [2] 

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
<class 'torch.utils.data.dataloader.DataLoader'>
Files already downloaded and verified
<class 'torch.utils.data.dataloader.DataLoader'>
Files already downloaded and verified


INFO:root:
        Name       Type Params
0        net        Net   62 K
1  net.conv1     Conv2d  456  
2   net.pool  MaxPool2d    0  
3  net.conv2     Conv2d    2 K
4    net.fc1     Linear   48 K
5    net.fc2     Linear   10 K
6    net.fc3     Linear  850  


Files already downloaded and verified


Validation sanity check:   0%|          | 0/12500 [00:00<?, ?batch/s]

<class 'torch.utils.data.dataloader.DataLoader'>
Files already downloaded and verified
<class 'torch.utils.data.dataloader.DataLoader'>
Files already downloaded and verified
<class 'torch.utils.data.dataloader.DataLoader'>
torch.Size([4, 3, 32, 32])


ValueError: too many values to unpack (expected 2)

In [None]:
comet_logger = CometLogger(workspace='DL4CV', project_name='HF4', experiment_name='compressed_20')

model = LitClassifier(20)

trainer = Trainer(gpus=1, min_epochs=10, logger=comet_logger)
r = trainer.fit(model)