In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import os
import cv2
from PIL import Image
from tqdm import tqdm_notebook as tqdm
import random
from matplotlib import pyplot as plt
import time
from collections import namedtuple
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class ImageTransform():
    """Phase-dependent image preprocessing, used as ``transform(img, phase) -> tensor``.

    * ``'train'``: random resized crop (keeping 50-100 % of the image area) plus a
      random horizontal flip, for augmentation.
    * ``'val'``: deterministic resize of the shorter side to ``val_resize``
      followed by a centre crop.

    Both pipelines finish with ``ToTensor`` and ``Normalize(mean, std)``.

    Args:
        resize: side length of the square crop produced in both phases.
        mean: per-channel means passed to ``Normalize``.
        std: per-channel standard deviations passed to ``Normalize``.
        val_resize: size the shorter image side is resized to before the 'val'
            centre crop. Defaults to 256, the value that used to be hard-coded.
            Keep it >= ``resize``; otherwise ``CenterCrop`` has to zero-pad.
    """

    def __init__(self, resize, mean, std, val_resize=256):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'val': transforms.Compose([
                transforms.Resize(val_resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
        }

    def __call__(self, img, phase):
        """Apply the pipeline for ``phase`` ('train' or 'val').

        Raises:
            KeyError: if ``phase`` is not one of the configured phases.
        """
        return self.data_transform[phase](img)

In [None]:
cat_directory = r'./Cat'
dog_directory = r'./Dog'
# Split sizes: the first N_TRAIN shuffled images are used for training, the last
# N_TEST are held out for a final test, and everything in between is validation.
N_TRAIN = 400
N_TEST = 10

cat_images_filepaths = sorted([os.path.join(cat_directory, f) for f in os.listdir(cat_directory)])
dog_images_filepaths = sorted([os.path.join(dog_directory, f) for f in os.listdir(dog_directory)])
images_filepaths = [*cat_images_filepaths, *dog_images_filepaths]
# Drop files OpenCV cannot decode (cv2.imread returns None for them).
correct_images_filepaths = [path for path in images_filepaths if cv2.imread(path) is not None]

# With too few images the slices below would overlap (the test slice would re-use
# training images) or leave the validation set empty.
if len(correct_images_filepaths) <= N_TRAIN + N_TEST:
    raise ValueError(
        f'need more than {N_TRAIN + N_TEST} readable images, '
        f'found {len(correct_images_filepaths)}'
    )

random.seed(42)  # fixed seed -> reproducible split
random.shuffle(correct_images_filepaths)
n_images = len(correct_images_filepaths)
train_images_filepaths = correct_images_filepaths[:N_TRAIN]
val_images_filepaths = correct_images_filepaths[N_TRAIN:n_images - N_TEST]
test_images_filepaths = correct_images_filepaths[n_images - N_TEST:]
print(len(train_images_filepaths), len(val_images_filepaths), len(test_images_filepaths))

In [None]:
class DogvsCatDataset(Dataset):
    """Binary cat/dog image dataset backed by a list of image file paths.

    Each item is ``(image, label)`` with ``label`` 1 for a dog and 0 for a cat.

    Args:
        file_list: paths of the image files.
        transform: callable invoked as ``transform(img, phase)`` (e.g. an
            ``ImageTransform``). When ``None`` the RGB PIL image is returned as-is.
        phase: key forwarded to ``transform`` ('train' or 'val').
    """

    def __init__(self, file_list, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.phase = phase

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 for a dog image and 0 for a cat image.

        The file name (up to its first '.') is checked first, as before. If it
        names neither class, the parent directory name is used instead, so paths
        such as ``./Cat/123.jpg`` are labelled correctly with any path separator.

        Raises:
            ValueError: if neither the file name nor its directory contains
                'Dog' or 'Cat'. Previously the raw string was returned, which
                only failed later, inside the DataLoader or the loss.
        """
        stem = os.path.basename(img_path).split('.')[0]
        parent = os.path.basename(os.path.dirname(img_path))
        for name in (stem, parent):
            if 'Dog' in name:
                return 1
            if 'Cat' in name:
                return 0
        raise ValueError(f'cannot infer a Cat/Dog label from path: {img_path!r}')

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        # Force 3 channels: grayscale/RGBA/CMYK files would otherwise yield tensors
        # whose channel count does not match the 3-channel normalisation. The
        # context manager closes the file handle PIL keeps open for lazy loading.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img, self.phase)
        return img, self._label_from_path(img_path)

In [None]:
# Network input resolution, per-channel RGB normalisation statistics (these match
# the usual ImageNet values) and the mini-batch size.
size = 224
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
batch_size = 32

In [None]:
# Same preprocessing configuration for both splits; `phase` selects the pipeline.
train_dataset = DogvsCatDataset(train_images_filepaths, transform=ImageTransform(size, mean, std), phase='train')
val_dataset = DogvsCatDataset(val_images_filepaths, transform=ImageTransform(size, mean, std), phase='val')

# Sanity check on one sample. Index the dataset once: calling __getitem__ twice
# decoded the file twice and applied two different random augmentations.
index = 0
sample_image, sample_label = train_dataset[index]
print(sample_image.size())
print(sample_label)

In [None]:
# Training batches are reshuffled every epoch; validation order stays fixed.
dataloader_dict = {
    'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
}
train_dataloader = dataloader_dict['train']
val_dataloader = dataloader_dict['val']

# Peek at one training batch to check tensor shapes and labels.
batch_iterator = iter(train_dataloader)
inputs, label = next(batch_iterator)
print(inputs.size())
print(label)

In [None]:
class BasicBlock(nn.Module):
    """Basic residual block: two 3x3 conv/BN layers plus a shortcut.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both 3x3 convolutions. Since
            ``expansion == 1`` this is also the block's output width.
        stride: stride of the first convolution (2 halves the spatial size).
        downsample: if truthy, the shortcut becomes a 1x1 conv + BN projection
            so that it matches the residual branch's shape.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride = 1, downsample = False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Projection shortcut, only created when requested; None means identity.
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if downsample
            else None
        )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

In [None]:
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand (x ``expansion``).

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: width of the inner 1x1/3x3 convolutions. The block outputs
            ``expansion * out_channels`` channels.
        stride: stride of the 3x3 convolution.
        downsample: if truthy, the shortcut becomes a 1x1 conv + BN projection to
            ``expansion * out_channels`` channels with the same stride.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride = 1, downsample = False):
        super().__init__()
        wide_channels = self.expansion * out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, wide_channels, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide_channels)
        self.relu = nn.ReLU(inplace=True)
        # Projection shortcut, only created when requested; None means identity.
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, wide_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(wide_channels),
            )
            if downsample
            else None
        )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

In [48]:
class Full_pre_activation_bottleneck(nn.Module):
    """Bottleneck block with full pre-activation: each conv is preceded by BN -> ReLU.

    Residual branch: BN-ReLU-conv1x1 -> BN-ReLU-conv3x3 -> BN-ReLU-conv1x1. The sum
    with the shortcut is returned without a trailing activation.

    Args:
        in_channels: channels of the incoming tensor (``bn1`` normalises these).
        out_channels: width of the inner convolutions. The block outputs
            ``expansion * out_channels`` channels.
        stride: stride of the 3x3 convolution.
        downsample: if truthy, the shortcut becomes a 1x1 conv + BN projection. It is
            only needed when the input and output shapes differ.

    NOTE(review): the shortcut projection here is applied to the raw input and
    followed by BN. The pre-activation ResNet paper feeds the pre-activated input to a
    plain 1x1 conv. Confirm this deviation is intended.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride = 1, downsample = False):
        super().__init__()
        wide_channels = self.expansion * out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)  # pre-activation of the block input
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)  # pre-activation of conv1's output
        self.conv3 = nn.Conv2d(out_channels, wide_channels, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)  # pre-activation of conv2's output
        self.relu = nn.ReLU(inplace=True)
        # Projection shortcut, only created when requested; None means identity.
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, wide_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(wide_channels),
            )
            if downsample
            else None
        )

    def forward(self, x):
        # bn1 returns a new tensor, so the in-place ReLU never touches `x` itself.
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out

In [49]:
class ResNet(nn.Module):
    """ResNet backbone with a linear classification head, built from a config.

    Args:
        config: ``(block, n_blocks, channels)``. ``block`` is a residual block class
            with an ``expansion`` attribute and the signature
            ``block(in_channels, out_channels, stride=1, downsample=False)``.
            ``n_blocks`` and ``channels`` give the block count and base width of each
            of the 4 stages.
        output_dim: size of the final linear layer (number of classes).

    ``forward(x)`` returns ``(logits, h)``, where ``h`` is the flattened output of the
    global average pool over the last stage.

    NOTE(review): pre-activation blocks return the residual sum without BN/ReLU. The
    pre-activation design usually adds a final BN-ReLU before pooling; this class
    does not.
    """
    def __init__(self, config, output_dim):
        super().__init__()

        block, n_blocks, channels = config
        self.in_channels = channels[0]
        assert len(n_blocks) == len(channels) == 4

        # Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        # Four stages; stages 2-4 halve the spatial resolution in their first block.
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride = 2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride = 2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride = 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        # get_resnet_layer left self.in_channels at the last stage's output width.
        self.fc = nn.Linear(self.in_channels, output_dim)

    def get_resnet_layer(self, block, n_blocks, channels, stride = 1):
        """Build one stage of ``n_blocks`` blocks with base width ``channels``.

        Only the first block receives ``stride`` and (if needed) a projection
        shortcut. The remaining blocks keep the shape unchanged. Side effect:
        updates ``self.in_channels`` to the stage's output width
        (``block.expansion * channels``).
        """
        out_channels = block.expansion * channels
        # A projection shortcut is needed whenever the residual branch changes the
        # tensor shape: a different channel count OR a spatial stride. (Checking
        # only the channel count breaks the residual add for a strided stage whose
        # width does not change.)
        downsample = stride != 1 or self.in_channels != out_channels

        layers = [block(self.in_channels, channels, stride, downsample)]
        for _ in range(1, n_blocks):
            layers.append(block(out_channels, channels))

        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)  # (batch, channels) pooled features
        x = self.fc(h)
        return x, h

In [50]:
# Immutable record describing a ResNet variant: block class, blocks per stage, stage widths.
ResNetConfig = namedtuple('ResNetConfig', 'block n_blocks channels')

In [51]:
# Block counts per stage for the ResNet-18/34/50/101/152 depths. All variants use
# stage base widths 64/128/256/512; Bottleneck-based stages output 4x these widths.
resnet18_config = ResNetConfig(block = BasicBlock,
                               n_blocks = [2,2,2,2],
                               channels = [64, 128, 256, 512])

resnet34_config = ResNetConfig(block = BasicBlock,
                               n_blocks = [3,4,6,3],
                               channels = [64, 128, 256, 512])

resnet50_config = ResNetConfig(block = Bottleneck,
                               n_blocks = [3, 4, 6, 3],
                               channels = [64, 128, 256, 512])

resnet101_config = ResNetConfig(block = Bottleneck,
                                n_blocks = [3, 4, 23, 3],
                                channels = [64, 128, 256, 512])

resnet152_config = ResNetConfig(block = Bottleneck,
                                n_blocks = [3, 8, 36, 3],
                                channels = [64, 128, 256, 512])

# ResNet-50 layout built from full pre-activation bottleneck blocks.
resnet50_config_full_pre_actiavation = ResNetConfig(block = Full_pre_activation_bottleneck,
                                                    n_blocks = [3, 4, 6, 3],
                                                    channels = [64, 128, 256, 512])
# Correctly spelled alias; the misspelled name above is kept because later cells use it.
resnet50_config_full_pre_activation = resnet50_config_full_pre_actiavation

In [28]:
# Two output classes: cat (label 0) and dog (label 1).
OUTPUT_DIM = 2
model = ResNet(config=resnet50_config, output_dim=OUTPUT_DIM)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [52]:
def calculate_accuracy(y_pred, y):
    """Return the fraction of samples whose arg-max prediction equals the target.

    Args:
        y_pred: scores with classes along dim 1, e.g. (batch, num_classes) logits.
        y: integer class targets, one per sample.

    Returns:
        A 0-dim float tensor in [0, 1].
    """
    predicted = y_pred.argmax(1, keepdim=True)
    hits = (predicted == y.view_as(predicted)).sum()
    return hits.float() / y.shape[0]

def train(model, iterator, optimizer, criterion, device):
    """Run one training epoch.

    Args:
        model: network whose forward returns ``(logits, features)``; only the
            logits are used.
        iterator: iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function, assumed to return the batch *mean*
            (the default for ``nn.CrossEntropyLoss`` used in this notebook).
        device: device the batches are moved to.

    Returns:
        ``(loss, accuracy)`` averaged over every sample seen. Each batch is weighted
        by its size, so a short final batch no longer counts as much as a full one.
        Returns ``(nan, nan)`` if ``iterator`` yields no batches.
    """
    total_loss = 0.0
    total_correct = 0.0
    n_samples = 0

    model.train()
    for (x, y) in tqdm(iterator):
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        y_pred, _ = model(x)
        loss = criterion(y_pred, y)
        acc = calculate_accuracy(y_pred, y)
        loss.backward()
        optimizer.step()

        # Turn the per-batch means back into sums before accumulating.
        n = y.shape[0]
        total_loss += loss.item() * n
        total_correct += acc.item() * n
        n_samples += n

    if n_samples == 0:
        return float('nan'), float('nan')
    return total_loss / n_samples, total_correct / n_samples

def evaluate(model, iterator, criterion, device):
    """Evaluate ``model`` on ``iterator`` without updating any weights.

    The model is switched to eval mode (e.g. BatchNorm uses its running
    statistics) and gradient tracking is disabled.

    Args:
        model: network whose forward returns ``(logits, features)``; only the
            logits are used.
        iterator: iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        criterion: loss function, assumed to return the batch *mean*
            (the default for ``nn.CrossEntropyLoss`` used in this notebook).
        device: device the batches are moved to.

    Returns:
        ``(loss, accuracy)`` averaged over every sample, with each batch weighted
        by its size. Returns ``(nan, nan)`` if ``iterator`` yields no batches; NaN
        never compares as a new best loss.
    """
    total_loss = 0.0
    total_correct = 0.0
    n_samples = 0

    model.eval()
    with torch.no_grad():
        for (x, y) in tqdm(iterator):
            x = x.to(device)
            y = y.to(device)
            y_pred, _ = model(x)
            loss = criterion(y_pred, y)
            acc = calculate_accuracy(y_pred, y)

            # Turn the per-batch means back into sums before accumulating.
            n = y.shape[0]
            total_loss += loss.item() * n
            total_correct += acc.item() * n
            n_samples += n

    if n_samples == 0:
        return float('nan'), float('nan')
    return total_loss / n_samples, total_correct / n_samples

def epoch_time(start_time, end_time):
    """Split the elapsed time between two timestamps into whole minutes and seconds.

    Args:
        start_time: start timestamp in seconds (e.g. from ``time.monotonic()``).
        end_time: end timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as ints, with both parts truncated.
    """
    total = end_time - start_time
    whole_minutes = int(total / 60)
    leftover = total - whole_minutes * 60
    return whole_minutes, int(leftover)

In [None]:
EPOCHS = 5
# NOTE(review): 1e-6 is an unusually small Adam learning rate for a freshly
# initialised ResNet-50 trained for only a few epochs; confirm it is intended.
LEARNING_RATE = 1e-6
CHECKPOINT_PATH = 'ResNet-model.pt'

# Move the model to the target device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss().to(device)

best_valid_loss = float('inf')
for epoch in range(EPOCHS):
    start_time = time.monotonic()
    train_loss, train_acc = train(model, train_dataloader, optimizer, criterion, device)
    valid_loss, valid_acc = evaluate(model, val_dataloader, criterion, device)

    # Keep only the weights with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), CHECKPOINT_PATH)

    end_time = time.monotonic()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Valid. Loss: {valid_loss:.3f} |  Valid. Acc: {valid_acc*100:.2f}%')

In [53]:
EPOCHS = 5
LEARNING_RATE = 1e-6  # same value as the bottleneck ResNet-50 training cell
# Separate checkpoint file: saving to 'ResNet-model.pt' overwrote the best
# bottleneck ResNet-50 weights written by the earlier training cell, and the two
# state_dicts are not interchangeable (their BatchNorm sizes differ).
CHECKPOINT_PATH = 'ResNet-preact-model.pt'

model2 = ResNet(resnet50_config_full_pre_actiavation, output_dim=OUTPUT_DIM).to(device)
optimizer = optim.Adam(model2.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss().to(device)

best_valid_loss = float('inf')
for epoch in range(EPOCHS):
    start_time = time.monotonic()
    train_loss, train_acc = train(model2, train_dataloader, optimizer, criterion, device)
    valid_loss, valid_acc = evaluate(model2, val_dataloader, criterion, device)

    # Keep only the weights with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model2.state_dict(), CHECKPOINT_PATH)

    end_time = time.monotonic()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Valid. Loss: {valid_loss:.3f} |  Valid. Acc: {valid_acc*100:.2f}%')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for (x, y) in tqdm(iterator):


  0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for (x, y) in tqdm(iterator):


  0%|          | 0/4 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

  0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

  0%|          | 0/4 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

  0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

  0%|          | 0/4 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

  0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

  0%|          | 0/4 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

  0%|          | 0/13 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,

  0%|          | 0/4 [00:00<?, ?it/s]

torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 64, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 256, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 56, 56])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 512, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28, 28])
torch.Size([32, 128, 28,