# In [1]:
from torchvision import transforms
from torch.utils.data import Dataset, random_split, DataLoader
import torch
import torch.nn as nn # layer들을 호출하기 위해서
import numpy as np
import torch.optim as optim # optimization method를 사용하기 위해서
import torch.nn.init as init # weight initialization 해주기 위해서

import torchvision
import torch.nn.functional as F
import torchvision.models as models
from parallel import DataParallelModel,DataParallelCriterion

# %matplotlib inline

from efficientnet_pytorch import EfficientNet


# Accuracy Calculates
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        outputs: per-class scores, reduced over dim 1 (expected shape
            (batch, n_class)). Only the arg-max is used, so logits and
            log-probabilities give the same result.
        labels: target class indices, one per row of ``outputs``.

    Returns:
        A 0-d float tensor holding correct / total.
    """
    predicted = torch.max(outputs, dim=1)[1]
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

class ModelBase(nn.Module):
    """Training/validation helpers for classifiers that emit log-probabilities.

    Subclasses implement ``forward``. The losses below use ``F.nll_loss``, so
    ``forward`` is expected to return log-probabilities (e.g. via
    ``nn.LogSoftmax``). Batches are ``(images, targets)`` pairs.
    """

    def training_step(self, batch):
        """Return the (batch-mean) NLL loss for one ``(images, targets)`` batch."""
        img, targets = batch
        out = self(img)
        loss = F.nll_loss(out, targets)
        return loss

    def validation_step(self, batch):
        """Return detached loss/accuracy for one batch, plus its sample count.

        ``batch_size`` lets ``validation_epoch_end`` weight each batch by how
        many samples it contributed.
        """
        img, targets = batch
        out = self(img)
        loss = F.nll_loss(out, targets)
        acc = accuracy(out, targets)
        return {
            'val_acc': acc.detach(),
            'val_loss': loss.detach(),
            'batch_size': len(targets),
        }

    def validation_epoch_end(self, outputs):
        """Aggregate ``validation_step`` results into epoch-level metrics.

        Each batch's loss/accuracy is already a mean over that batch, so the
        batches are combined as a sample-weighted average. A smaller final
        batch therefore no longer skews the epoch result. Results lacking a
        ``batch_size`` entry fall back to an unweighted mean over batches.

        Returns:
            dict with Python floats under ``'val_loss'`` and ``'val_acc'``.
        """
        losses = torch.stack([x['val_loss'] for x in outputs])
        accs = torch.stack([x['val_acc'] for x in outputs])
        if all('batch_size' in x for x in outputs):
            sizes = [float(x['batch_size']) for x in outputs]
        else:
            sizes = [1.0] * len(outputs)
        epoch_loss = self._weighted_mean(losses, sizes)
        epoch_acc = self._weighted_mean(accs, sizes)
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    @staticmethod
    def _weighted_mean(values, weights):
        """Return sum(values * weights) / sum(weights).

        The weights are built on ``values``' device and dtype. Loss (model
        device) and accuracy (CPU tensor from ``accuracy``) can live on
        different devices.
        """
        w = torch.tensor(weights, dtype=values.dtype, device=values.device)
        return (values * w).sum() / w.sum()

    def epoch_end(self, epoch, result):
        """Print the epoch's train loss and validation metrics."""
        print("Epoch [{}] : train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result["train_loss"], result["val_loss"], result["val_acc"]))




class PretrainedEfficientNet_V2(ModelBase):
    """Pretrained EfficientNet backbone with a replacement classification head.

    The backbone's ``_fc`` layer is replaced by
    ``Linear -> ReLU -> Dropout -> Linear -> LogSoftmax``. ``forward`` therefore
    returns log-probabilities, which is what ``F.nll_loss`` in ``ModelBase``
    expects.

    Args:
        n_class: number of output classes.
        model_name: variant passed to ``EfficientNet.from_pretrained``
            (default ``'efficientnet-b4'``, as before).
        hidden_dim: width of the hidden layer in the new head (default 512).
        dropout: dropout probability after the hidden layer (default 0.5).
    """

    def __init__(self, n_class, model_name='efficientnet-b4', hidden_dim=512, dropout=0.5):
        super().__init__()

        self.network = EfficientNet.from_pretrained(model_name)
        # Size the new head from the input width of the backbone's own classifier.
        num_ftrs = self.network._fc.in_features
        self.network._fc = nn.Sequential(
            nn.Linear(num_ftrs, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, n_class),
            # Log-probabilities pair with the F.nll_loss used in ModelBase.
            nn.LogSoftmax(dim=1),
        )

    def forward(self, xb):
        """Return per-class log-probabilities for the image batch ``xb``."""
        return self.network(xb)
        
    
    
class DataParallel(PretrainedEfficientNet_V2):
    """``PretrainedEfficientNet_V2`` whose network is wrapped in ``nn.DataParallel``.

    Args:
        n_class: number of output classes.
        device_ids: CUDA device indices to replicate the network across.
            ``None`` (default) lets ``nn.DataParallel`` use every visible GPU,
            so construction no longer fails on hosts with fewer than 8 GPUs.
            Pass e.g. ``list(range(8))`` to pin specific devices.
    """

    def __init__(self, n_class, device_ids=None):
        super().__init__(n_class)

        # Keep a handle to the unwrapped network for getModule().
        # NOTE: assigning an nn.Module attribute registers it as a submodule, so
        # state_dict() lists these weights under both ``module.`` and
        # ``network.module.``; left as-is so existing checkpoints still load.
        self.module = self.network

        self.network = nn.DataParallel(self.network, device_ids=device_ids)

    def getModule(self):
        """Return the unwrapped network (without the DataParallel wrapper)."""
        return self.module
        
    
# Wraps an existing indexable dataset of (image, label) pairs and optionally
# transforms each image.
# NOTE: this class reuses the name ``Dataset``. The base class in the ``class``
# statement is resolved before the name is rebound, so it does inherit from
# torch.utils.data.Dataset. Afterwards, the module-level name ``Dataset`` refers
# to this wrapper rather than torch's class.
class Dataset(Dataset):
    """Apply an optional ``transform`` to the images of a wrapped dataset.

    Args:
        ds: indexable, sized dataset whose items unpack as ``(img, label)``.
        transform: optional callable applied to each ``img``; ``None`` returns
            images unchanged.
    """

    def __init__(self, ds, transform=None):
        self.ds = ds
        self.transform = transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        img, label = self.ds[idx]
        if self.transform is not None:
            img = self.transform(img)
        # Return the pair whether or not a transform is configured (the return
        # used to sit inside the ``if``, so transform=None yielded None).
        return img, label
        
        
def to_device(data, device):
    """Move ``data`` to ``device``.

    Lists and tuples are handled recursively and always come back as lists.
    Anything else is moved with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))
        
        
class DeviceDataLoader:
    """Iterable wrapper that moves every batch from ``dl`` onto ``device``.

    Batches are transferred lazily, one at a time as they are iterated. The
    wrapper reports the same length as the wrapped loader.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        yield from (to_device(b, self.device) for b in self.dl)
            
            
@torch.no_grad()
def evaluate(model, val_loader):
    """Run one validation pass and return the aggregated metrics.

    Switches ``model`` to eval mode and does not restore training mode
    afterwards. It runs ``validation_step`` on every batch with autograd
    disabled, then reduces the per-batch results with ``validation_epoch_end``.
    """
    model.eval()
    per_batch = list(map(model.validation_step, val_loader))
    return model.validation_epoch_end(per_batch)