In [None]:
!pip install torch torchvision

In [None]:
!pip install pytorch-lightning

In [None]:
!pip3 install ipywidgets

In [None]:
!jupyter nbextension enable --py widgetsnbextension

In [None]:
!pip3 install comet_ml

A tanítás eredényei itt tekinthetőek meg: https://www.comet.ml/dl4cv/hf4/view/WVhG6as2RKTQhHRkl3gBEnRRf. A magas tömörítésű input rosszabb eredményt ért el.

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning.core.lightning import LightningModule

import numpy as np
import cv2

In [2]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

jpeg_blur_counter = 0
def get_jpeg_blur(quality):
    def jpeg_blur(img):
        img = np.array(img)
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
        result, encimg = cv2.imencode('.jpg', img, encode_param)
        return cv2.imdecode(encimg, 1)
    return jpeg_blur

In [6]:
import torch.optim as optim

class LitClassifier(LightningModule):
    def __init__(self, quality):
        super().__init__()
        self.quality = quality
        self.net = Net()
        if quality is not None:
            self.transform = transforms.Compose(
                [transforms.Lambda(get_jpeg_blur(quality)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])
        else:
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ])

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def train_dataloader(self):
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=self.transform)
        return torch.utils.data.DataLoader(trainset, batch_size=4,
                                                  shuffle=True, pin_memory=True, num_workers=2)
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        pred = y_hat.argmax(
            1, keepdim=True
        )  # get the index of the max log-probability
        acc = pred.eq(y.view_as(pred)).double()
        return {'val_loss': F.cross_entropy(y_hat, y), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss, 'val_acc': avg_acc}
        print(f'Validation accuracy: {avg_acc}')
        return {'val_loss': avg_loss, 'val_acc': avg_acc, 'log': tensorboard_logs}

    def val_dataloader(self):
        testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                              download=True, transform=self.transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                                   shuffle=False, pin_memory=True)
        return testloader


In [7]:
from pytorch_lightning import Trainer
from pytorch_lightning.logging import CometLogger
import os

comet_logger = CometLogger(api_key='cgss7piePhyFPXRw1J2uUEjkQ', workspace='dl4cv', project_name='HF4', experiment_name='original')

model = LitClassifier(None)

trainer = Trainer(gpus=1, min_epochs=10, logger=comet_logger)
r = trainer.fit(model)



# classes = ('plane', 'car', 'bird', 'cat',
#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

CometLogger will be initialized in online mode
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/dl4cv/hf4/08d9485b23ed4bc6aef4c7828a1d824b

GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params
------------------------------------
0 | net       | Net       | 62 K  
1 | net.conv1 | Conv2d    | 456   
2 | net.pool  | MaxPool2d | 0     
3 | net.conv2 | Conv2d    | 2 K   
4 | net.fc1   | Linear    | 48 K  
5 | net.fc2   | Linear    | 10 K  
6 | net.fc3   | Linear    | 850   


Files already downloaded and verified


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Validation accuracy: 0.125
Files already downloaded and verified
Files already downloaded and verified


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.496499987457355


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5521999860502547


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.581199985317653


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5978999848957756


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6034999847543077


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6107999845698941


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6172999844056903


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6178999843905331


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6267999841657002


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6241999842313817


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6239999842364341


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6227999842667487


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6243999842263293


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.621699984294537


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6236999842440127


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6111999845597893


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6113999845547369


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6008999848199892


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6145999844738981


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6144999844764243


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6040999847391504


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5908999850726104


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6067999846709426


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6039999847416766


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6044999847290455


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6010999848149368


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6093999846052611


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5953999849589309


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6079999846406281


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6063999846810475


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.607599984650733


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6096999845976825


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.6020999847896746


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5993999848578824


Detected KeyboardInterrupt, attempting graceful shutdown...


KeyboardInterrupt: 

In [9]:
comet_logger = CometLogger(api_key='cgss7piePhyFPXRw1J2uUEjkQ', workspace='DL4CV', project_name='HF4', experiment_name='compressed_20')

model = LitClassifier(20)

trainer = Trainer(gpus=1, min_epochs=10, logger=comet_logger)
r = trainer.fit(model)

CometLogger will be initialized in online mode
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/dl4cv/hf4/08d9485b23ed4bc6aef4c7828a1d824b
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     epoch [43662]      : (0, 34)
COMET INFO:     train_loss [43628] : (4.76837158203125e-07, 5.2475175857543945)
COMET INFO:     val_acc [34]       : (0.496499987457355, 0.6267999841657002)
COMET INFO:     val_loss [34]      : (1.0986441373825073, 1.6192470788955688)
COMET INFO:   Others:
COMET INFO:     Name                       : original
COMET INFO:     experiment_stopped_by_user : True
COMET INFO:   Uploads:
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     installed packages  : 1
COMET INFO:     os packages         : 1
COMET INFO: ------------------------

Files already downloaded and verified


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Validation accuracy: 0.125
Files already downloaded and verified
Files already downloaded and verified


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.47299998805101495


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.528499986648967


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5421999863028759


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5658999857041636


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.572299985542486


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5678999856536393


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.578399985388387


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5791999853681773


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5833999852620764


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5831999852671288


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5818999852999696


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5633999857673189


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5809999853227055


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5740999854970141


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5778999854010181


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5800999853454414


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5744999854869093


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5753999854641734


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.570399985590484


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5697999856056413


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5675999856612179


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5551999859744683


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.557699985911313


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5606999858355266


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5613999858178431


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Validation accuracy: 0.5602999858456315


Detected KeyboardInterrupt, attempting graceful shutdown...


KeyboardInterrupt: 