In [51]:
from efficientnet_pytorch import EfficientNet

In [5]:
import torch
import torchvision

In [6]:
import glob

In [13]:
train_root_path = './data/cifar/train/*.png'
test_root_path = './data/cifar/test/*.png'

labels = []
with open('./data/cifar/labels.txt', 'r') as f:
    labels = f.readlines()
labels_ = [label.replace('\n','') for label in labels]

train_imgs = glob.glob(train_root_path)
test_imgs = glob.glob(test_root_path)

In [15]:
# Map class name -> 0-based class id. CrossEntropyLoss expects integer targets
# in [0, num_classes - 1], so ids must start at 0 (the previous idx+1 shifted
# every class up by one and left id 0 unused).
labels_idx_dict = {label: idx for idx, label in enumerate(labels_)}

In [19]:
import cv2

In [20]:
# Peek at the decoded pixels of the first training file
# (OpenCV returns channels in BGR order, per its imread documentation).
first_train_img_path = train_imgs[0]
cv2.imread(first_train_img_path)

array([[[ 63,  62,  59],
        [ 45,  46,  43],
        [ 43,  48,  50],
        ...,
        [108, 132, 158],
        [102, 125, 152],
        [103, 124, 148]],

       [[ 20,  20,  16],
        [  0,   0,   0],
        [  0,   8,  18],
        ...,
        [ 55,  88, 123],
        [ 50,  83, 119],
        [ 57,  87, 122]],

       [[ 21,  24,  25],
        [  0,   7,  16],
        [  8,  27,  49],
        ...,
        [ 50,  84, 118],
        [ 50,  84, 120],
        [ 42,  73, 109]],

       ...,

       [[ 96, 170, 208],
        [ 34, 153, 201],
        [ 26, 161, 198],
        ...,
        [ 70, 133, 160],
        [  7,  31,  56],
        [ 20,  34,  53]],

       [[ 96, 139, 180],
        [ 42, 123, 173],
        [ 30, 144, 186],
        ...,
        [ 94, 148, 184],
        [ 34,  62,  97],
        [ 34,  53,  83]],

       [[116, 144, 177],
        [ 94, 129, 168],
        [ 87, 142, 179],
        ...,
        [140, 184, 216],
        [ 84, 118, 151],
        [ 72,  92, 123]]

In [21]:
# Example training path: the class name is embedded after the last '_' in the file name.
sample_train_path = train_imgs[0]
sample_train_path

'./data/cifar/train/0_frog.png'

In [25]:
# Class name = text after the last '_' in the file name, with '.png' removed.
file_name = train_imgs[0].rsplit('/', 1)[-1]
label = file_name.rsplit('_', 1)[-1].replace('.png', '')
label

'frog'

In [31]:
# Per-channel normalisation constants (mean, std) applied after ToTensor.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

# Train and test currently share the same pipeline; each mode still gets its
# own Compose so either can be changed independently later (e.g. augmentation).
transforms_dict = {
    mode: torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(NORM_MEAN, NORM_STD),
    ])
    for mode in ('train', 'test')
}

In [33]:
class CIFAR10_Dataset(torch.utils.data.Dataset):
    """Map-style dataset over image files whose names end in ``_<class>.png``.

    Args:
        img_path: sequence of image file paths; the class name is parsed from
            each file's base name (text after the last '_', extension removed).
        labels_idx_dict: mapping from class name to integer class id.
        mode: key into ``transforms_dict`` selecting the pipeline (e.g. 'train').
        transforms_dict: optional mapping of mode -> callable applied to the
            loaded image. When omitted, ``__getitem__`` returns the RGB array.

    ``__getitem__`` returns ``(image, class_id)`` and raises FileNotFoundError
    if the file cannot be decoded, or KeyError for an unknown class name.
    """

    def __init__(self, img_path, labels_idx_dict, mode, transforms_dict=None):
        # Actually call the base initialiser (it was previously referenced but not invoked).
        super().__init__()
        self.img_path = img_path
        self.labels_idx_dict = labels_idx_dict
        self.mode = mode
        self.transforms_dict = transforms_dict

    @staticmethod
    def _label_from_path(path):
        """Return the class name encoded after the last '_' in the file's base name."""
        import os

        # basename/splitext handle platform separators and only strip the real extension.
        stem, _ext = os.path.splitext(os.path.basename(path))
        return stem.rsplit('_', 1)[-1]

    def __getitem__(self, index):
        img_p = self.img_path[index]
        img = cv2.imread(img_p)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None instead of raising.
            raise FileNotFoundError(f'could not read image: {img_p}')
        # OpenCV decodes to BGR; convert to RGB, the channel order torchvision conventions assume.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = self.labels_idx_dict[self._label_from_path(img_p)]

        if self.transforms_dict:
            img = self.transforms_dict[self.mode](img)

        return img, label

    def __len__(self):
        return len(self.img_path)

In [44]:
train_ds = CIFAR10_Dataset(train_imgs, labels_idx_dict, 'train', transforms_dict=transforms_dict)
# The held-out dataset must read the test images; it previously reused train_imgs,
# so any "test" metric would have been measured on training data.
test_ds = CIFAR10_Dataset(test_imgs, labels_idx_dict, 'test', transforms_dict=transforms_dict)

In [45]:
BATCH_SIZE = 8  # images per mini-batch, shared by both loaders

# Shuffle only the training split; evaluation order stays fixed.
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [50]:
# Randomly initialised EfficientNet-B0 whose classifier head is sized from the
# label mapping rather than the library's default (ImageNet, 1000 classes).
# max id + 1 keeps every id in labels_idx_dict a valid CrossEntropyLoss target
# whether the ids start at 0 or 1.
net = EfficientNet.from_name(
    'efficientnet-b0',
    num_classes=max(labels_idx_dict.values()) + 1,
)

In [52]:
# Cross-entropy on raw logits vs. integer class ids, averaged over the batch.
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

In [54]:
LEARNING_RATE = 1e-3
MOMENTUM = 0.9

# Plain SGD with momentum over every trainable parameter of the network.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
)

In [55]:
# Number of full passes over the training loader.
epochs_num: int = 50

In [None]:
for epoch in range(epochs_num):
    net.train()  # ensure training-mode behaviour (dropout / batch-norm updates) every epoch
    epoch_loss = 0.0
    epoch_correct = 0
    epoch_seen = 0

    # Name the targets `targets` so the loop does not overwrite the `labels`
    # list of class-name lines read from labels.txt.
    for imgs, targets in train_dl:
        optimizer.zero_grad()
        output = net(imgs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        # criterion (CrossEntropyLoss with default mean reduction) averages over
        # the batch; re-weight by batch size so the epoch mean is exact even when
        # the last batch is smaller.
        epoch_loss += loss.item() * batch_size
        epoch_correct += (output.argmax(dim=1) == targets).sum().item()
        epoch_seen += batch_size

    print(
        f'epoch {epoch + 1}/{epochs_num}: '
        f'loss={epoch_loss / max(epoch_seen, 1):.4f} '
        f'acc={epoch_correct / max(epoch_seen, 1):.4f}'
    )