<a href="https://colab.research.google.com/github/mot1122/CNN_model/blob/main/use_pytorch/finetuned_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytorch-lightning==0.7.1

Collecting pytorch-lightning==0.7.1
  Downloading pytorch-lightning-0.7.1.tar.gz (6.0 MB)
[K     |████████████████████████████████| 6.0 MB 7.4 MB/s 
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 56.5 MB/s 
Building wheels for collected packages: pytorch-lightning, future
  Building wheel for pytorch-lightning (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-lightning: filename=pytorch_lightning-0.7.1-py3-none-any.whl size=145329 sha256=311735bbb5f00921074e8710edf47c6175ec3aec6334a7b65e947f5b7e4c042f
  Stored in directory: /root/.cache/pip/wheels/a5/c0/6c/ed64904da20814878f410c520ae61c062c6d7e93bf5c27dcd4
  Building wheel for future (setup.py) ... [?25l[?25hdone
  Created wheel for future: filename=future-0.18.2-py3-none-any.whl size=491070 sha256=96c96bc412a1c9419789be855ae110979f14d9f4283d74ffabd0b478eedd7838
  Stored in directory: /root/.cache/pip/wheels/56/b0/fe/4410d17b32f1f0c3cf54cdfb2bc04d7b4b

In [2]:
import torch,torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer

# Dataset

In [3]:
transform=transforms.Compose([
    transforms.ToTensor()
])

In [4]:
train_val=torchvision.datasets.CIFAR10(root="data",train=True,download=True,transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data


In [5]:
test=torchvision.datasets.CIFAR10(root="data",train=False,download=True,transform=transform)

Files already downloaded and verified


In [6]:
n_train=int(len(train_val)*0.8)
n_val=len(train_val)-n_train

In [7]:
torch.manual_seed(0)
train,val=torch.utils.data.random_split(train_val,[n_train,n_val])

In [8]:
len(train),len(val),len(test)

(40000, 10000, 10000)

# Model

In [9]:
class TrainNet(pl.LightningModule):
  def train_dataloader(self):
    return torch.utils.data.DataLoader(train,self.batch_size,shuffle=True)
  def training_step(self,batch,batch_nb):
    x,t=batch
    y=self.forward(x)
    loss=self.lossfun(y,t)
    results={"loss":loss}
    return results

In [10]:
class ValidationNet(pl.LightningModule):
  def val_dataloader(self):
        return torch.utils.data.DataLoader(val, self.batch_size)
  def validation_step(self,batch,batch_nb):
    x,t=batch
    y=self.forward(x)
    loss=self.lossfun(y,t)
    y_label=torch.argmax(y,dim=1)
    acc=torch.sum(t==y_label)/len(t)
    results={"val_loss":loss,"val_acc":acc}
    return results
  def validation_end(self,outputs):
    avg_loss=torch.stack([x["val_loss"] for x in outputs]).mean()
    avg_acc=torch.stack([x["val_acc"] for x in outputs]).mean()
    results={"val_loss":avg_loss,"val_acc":avg_acc}
    return results

In [11]:
class TestNet(pl.LightningModule):
    def test_dataloader(self):
        return torch.utils.data.DataLoader(test, self.batch_size)

    def test_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'test_loss': loss, 'test_acc': acc}
        return results

    def test_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        results = {'test_loss': avg_loss, 'test_acc': avg_acc}
        return results

In [12]:
class Net(TrainNet,ValidationNet,TestNet):
  def __init__(self,batch_size=256):
    super().__init__()
    self.batch_size=batch_size
    self.conv1=nn.Conv2d(3,64,3,padding=1)
    self.bn1=nn.BatchNorm2d(64)
    self.conv2=nn.Conv2d(64,128,3,padding=1)
    self.bn2=nn.BatchNorm2d(128)
    self.conv3=nn.Conv2d(128,256,3,padding=1)
    self.bn3=nn.BatchNorm2d(256)
    self.conv4=nn.Conv2d(256,512,3,padding=1)
    self.bn4=nn.BatchNorm2d(512)
    self.fc=nn.Linear(2048,10)
  def lossfun(self,y,t):
    return F.cross_entropy(y,t)
  def configure_optimizers(self):
    return torch.optim.SGD(self.parameters(),lr=0.01)
  def forward(self,x):
    x=self.conv1(x)
    x=self.bn1(x)
    x=F.relu(x)
    x=F.max_pool2d(x,2,2)

    x=self.conv2(x)
    x=self.bn2(x)
    x=F.relu(x)
    x=F.max_pool2d(x,2,2)

    x=self.conv3(x)
    x=self.bn3(x)
    x=F.relu(x)
    x=F.max_pool2d(x,2,2)

    x=self.conv4(x)
    x=self.bn4(x)
    x=F.relu(x)
    x=F.max_pool2d(x,2,2)

    x=x.view(-1,2048)
    x=self.fc(x)
    
    return x

In [20]:
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark=False
torch.manual_seed(0)

net=Net()
trainer=Trainer(gpus=1,max_epochs=10,batch_size=1024)
trainer.fit(net)

Validation sanity check:   0%|          | 0/5 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

1

In [21]:
trainer.test()
trainer.callback_metrics

Testing:   0%|          | 0/40 [00:00<?, ?it/s]

----------------------------------------------------------------------------------------------------
TEST RESULTS
{}
----------------------------------------------------------------------------------------------------


{'epoch': 9,
 'loss': 0.8237280249595642,
 'test_acc': 0.754589855670929,
 'test_loss': 0.7008712887763977,
 'val_acc': 0.746289074420929,
 'val_loss': 0.7219057083129883}

# Resnet18

In [22]:
from torchvision.models import resnet18

In [23]:
resnet=resnet18(pretrained=True)

In [24]:
transform=transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Dataset

In [25]:
train_val=torchvision.datasets.CIFAR10(root="data",train=True,download=True,transform=transform)
test=torchvision.datasets.CIFAR10(root="data",train=False,download=True,transform=transform)

n_train=int(len(train_val)*0.8)
n_val=len(train_val)-n_train

torch.manual_seed(0)
train,val=torch.utils.data.random_split(train_val,[n_train,n_val])

len(train),len(val),len(test)

Files already downloaded and verified
Files already downloaded and verified


(40000, 10000, 10000)

In [26]:
class Net(TrainNet,ValidationNet, TestNet):
  def __init__(self,batch_size=256):
    super().__init__()
    self.batch_size=batch_size
    self.conv=resnet
    self.fc=nn.Linear(1000,10)
    for param in self.conv.parameters():
      param.requires_grad=False
      
  def lossfun(self, y, t):
      return F.cross_entropy(y, t)

  def configure_optimizers(self):
      return torch.optim.SGD(self.parameters(), lr=0.01)

  def forward(self, x):
      x = self.conv(x)
      x = self.fc(x)
      return x

In [30]:
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark=False
torch.manual_seed(0)

net=Net()
trainer=Trainer(gpus=1,max_epochs=10,batch_size=1024)
trainer.fit(net)

Validation sanity check:   0%|          | 0/5 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

Validating:   0%|          | 0/40 [00:00<?, ?it/s]

1

In [31]:
trainer.test()
trainer.callback_metrics

Testing:   0%|          | 0/40 [00:00<?, ?it/s]

----------------------------------------------------------------------------------------------------
TEST RESULTS
{}
----------------------------------------------------------------------------------------------------


{'epoch': 9,
 'loss': 0.8237280249595642,
 'test_acc': 0.754589855670929,
 'test_loss': 0.7008712887763977,
 'val_acc': 0.746289074420929,
 'val_loss': 0.7219057083129883}