In [5]:
import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import torch

import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, models, transforms

In [7]:
ddir = '/content/drive/MyDrive/study/study by lecture/implementation/dataset/hym_data'
batch_size = 64
num_workers = 2

# Per-channel normalization statistics shared by BOTH splits. Validation must
# be normalized exactly like training, otherwise the model is evaluated on
# differently-scaled inputs (and an empty mean/std list fails at load time).
norm_mean = [0.490, 0.449, 0.411]
norm_std = [0.231, 0.221, 0.230]

data_transformers = {
    # Training: random resized crop + horizontal flip for augmentation.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]),
    # Validation: deterministic resize + center crop to the same 224x224 size.
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]),
}

# Expects <ddir>/train and <ddir>/val, each laid out for ImageFolder
# (one sub-directory per class).
img_data = {
    k: datasets.ImageFolder(os.path.join(ddir, k), data_transformers[k])
    for k in ['train', 'val']
}

# Only the training split is shuffled; evaluation does not need random order.
dloaders = {
    k: torch.utils.data.DataLoader(
        img_data[k],
        batch_size=batch_size,
        shuffle=(k == 'train'),
        num_workers=num_workers,
    )
    for k in ['train', 'val']
}

dset_size = {x: len(img_data[x]) for x in ['train', 'val']}
classes = img_data['train'].classes

dvc = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
from torch.nn.modules.batchnorm import BatchNorm2d
import time
import torch.nn as nn

class BasicBlock(nn.Module):
  def __init__(self, in_channels, out_channels, stride = 1):
    super().__init__()
    
    self.residual_function = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLu(inplace = True),
        nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1, bias = False),
        nn.BatchNorm2d(out_channels)
    )

    self.shorcut = nn.Sequential()

    if stride != 1 or in_channels != out_channels:
      self.shortcut = nn.Sequential(
          nn.Conv2d(in_channels, out_channels, kernel_size=1, stride = stride, bias = False),
          nn.BatchNorm2d(out_channels)
      )
