In [179]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import cv2

In [180]:
class Dataset2class(torch.utils.data.Dataset):
    def __init__(self, path_dir1:str, path_dir2:str):
        super().__init__()

        self.path_dir1 = path_dir1
        self.path_dir2 = path_dir2
       
        self.list_dir1 = sorted(os.listdir(path_dir1))
        self.list_dir2 = sorted(os.listdir(path_dir2))
        
    def __len__(self):
        return len(self.list_dir1) + len(self.list_dir2)
        
    def __getitem__(self, index):
        if index < len(self.list_dir1):
            class_id = 0
            img_path = os.path.join(self.path_dir1, self.list_dir1[index])
        else: 
            class_id = 1
            index -= len(self.list_dir1)
            img_path = os.path.join(self.path_dir2, self.list_dir2[index])

        img = cv2.imread(img_path, cv2.IMREAD_COLOR) 
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32)
        img = img/255.0

        img = cv2.resize(img, (64, 64), interpolation=cv2.INTER_AREA)

        img = img.transpose((2, 0, 1))

        t_img = torch.from_numpy(img)

        t_class_id = torch.tensor(class_id)

        return {'img': t_img, 'label': t_class_id}   

In [181]:
train_ice_cream = '/home/sasha/Рабочий стол/data/food-101-tiny/train/ice_cream'
train_french_toast = '/home/sasha/Рабочий стол/data/food-101-tiny/train/french_toast'
test_ice_cream = '/home/sasha/Рабочий стол/data/food-101-tiny/valid/ice_cream'
test_french_toast = '/home/sasha/Рабочий стол/data/food-101-tiny/valid/french_toast'

train_ds_ic_ft = Dataset2class(train_ice_cream, train_french_toast)
test_ds_ic_ft = Dataset2class(test_ice_cream, test_french_toast)

In [182]:
batch_size = 16
train_dataloader = torch.utils.data.DataLoader(train_ds_ic_ft, shuffle=True, batch_size=batch_size, num_workers=1, drop_last=True)

test_dataloader = torch.utils.data.DataLoader(train_ds_ic_ft, shuffle=True, batch_size=batch_size, num_workers=1, drop_last=False)

In [195]:
class ConvolNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.act = nn.LeakyReLU()
        self.maxpool = nn.MaxPool2d(2, 2)
        self.conv0 = nn.Conv2d(3, 32, 3, stride=1, padding=0)
        self.conv1 = nn.Conv2d(32, 32, 3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=0)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=0)

        self.adaptive = nn.AdaptiveAvgPool2d((1,1))
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(128, 10)
        self.linear2 = nn.Linear(10, 2)
        
    def forward(self, x):

        out = self.conv0(x)
        out = self.act(out)
        out = self.maxpool(out)
        
        out = self.conv1(out)
        out = self.act(out)
        out = self.maxpool(out)
        
        out = self.conv2(out)
        out = self.act(out)
        out = self.maxpool(out)

        out = self.conv3(out)
        out = self.act(out)

        out = self.adaptive(out)
        out = self.flatten(out)
        out = self.linear1(out)
        out = self.act(out)
        out = self.linear2(out)

        return out
        

In [196]:
net = ConvolNet()

In [197]:
for sample in train_dataloader:
    img = sample['img']
    label = sample['label']
    net(img)
    break

In [204]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr = 1e-3)

In [205]:
def accuracy(pred, label):
    answer = F.softmax(pred.detach(), dim=1).numpy().argmax(1) == label.numpy().argmax(1)
    return answer.mean()

In [209]:
epochs = 10 
for epoch in range(epochs):
    loss_val = 0
    acc_val = 0 
    for sample in (pbar := tqdm(train_dataloader)):
        img, label = sample['img'], sample['label']
        optimizer.zero_grad()
        label = F.one_hot(label, 2).float()
        pred = net(img)
        loss = loss_fn(pred, label)

        loss.backward()
        loss_item = loss.item()
        loss_val += loss_item

        optimizer.step()

        acc_current = accuracy(pred, label)
        acc_val += acc_current

    pbar.set_description(f'loss: {loss_item:.5f}\taccuracy: {acc_current:.3f}')
    print(loss_val/len(train_dataloader))
    print(acc_val/len(train_dataloader))


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 14.19it/s]


0.48090952303674483
0.7916666666666666


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 14.44it/s]


0.4779139641258452
0.7951388888888888


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 13.74it/s]


0.4768507132927577
0.78125


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 12.32it/s]


0.4750903331571155
0.7777777777777778


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 13.23it/s]


0.49060459103849197
0.7916666666666666


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 13.18it/s]


0.4636906103955375
0.7881944444444444


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 13.35it/s]


0.4895409312513139
0.7847222222222222


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 13.85it/s]


0.43866215149561566
0.8020833333333334


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 13.64it/s]


0.39110497136910755
0.8298611111111112


100%|███████████████████████████████████████████| 18/18 [00:01<00:00, 12.92it/s]

0.43657660318745506
0.7916666666666666





In [None]:
loss_val = 0
acc_val = 0 
for sample in (pbar := tqdm(test_dataloader)):
    with torch.no_grad():
        img, label = sample['img'], sample['label']
        
        label = F.one_hot(label, 2).float()
        pred = net(img)
        
        loss = loss_fn(pred, label)
        loss_item = loss.item()
        loss_val += loss_item
        
        acc_current = accuracy(pred, label)
        acc_val += acc_current
    
    pbar.set_description(f'loss: {loss_item:.5f}\taccuracy: {acc_current:.3f}')
print(loss_val/len(test_dataloader))
print(acc_val/len(test_dataloader))
