In [1]:
from torch import nn
import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from torchvision.io import read_image
from torchvision.transforms import Resize
import torchvision
import torchvision.models as models
import matplotlib.pyplot as plt
import torch.nn.functional as F
import os
import datetime

from models import CustomModel
from custom_dataset import FoodDataset
from utils import preprocess_image

Чтение данных в Dataset и Dataloader

In [2]:
food_data_path = "../nn2_data/training/food/"
food_data = [(1, read_image(food_data_path + filename, mode=torchvision.io.ImageReadMode.RGB).to(dtype=torch.float32)) for filename in os.listdir(food_data_path)]
food_data = list(map(lambda x: (x[0], preprocess_image(x[1])), food_data))
non_food_data_path = "../nn2_data/training/non_food/"
non_food_data = [(0, read_image(non_food_data_path + filename, mode=torchvision.io.ImageReadMode.RGB).to(dtype=torch.float32)) for filename in os.listdir(non_food_data_path)]
non_food_data = list(map(lambda x: (x[0], preprocess_image(x[1])), non_food_data))

food_dataset = FoodDataset(food_data)
non_food_dataset = FoodDataset(non_food_data)
train_dataset = ConcatDataset([food_dataset, non_food_dataset])
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)


food_data_path = "../nn2_data/validation/food/"
food_data = [(1, read_image(food_data_path + filename, mode=torchvision.io.ImageReadMode.RGB).to(dtype=torch.float32)) for filename in os.listdir(food_data_path)[:100]]
food_data = list(map(lambda x: (x[0], preprocess_image(x[1])), food_data))
non_food_data_path = "../nn2_data/validation/non_food/"
non_food_data = [(0, read_image(non_food_data_path + filename, mode=torchvision.io.ImageReadMode.RGB).to(dtype=torch.float32)) for filename in os.listdir(non_food_data_path)[:100]]
non_food_data = list(map(lambda x: (x[0], preprocess_image(x[1])), non_food_data))



food_dataset = FoodDataset(food_data)
non_food_dataset = FoodDataset(non_food_data)
validation_dataset = ConcatDataset([food_dataset, non_food_dataset])
validation_dataloader = DataLoader(validation_dataset, batch_size=8, shuffle=True)



In [3]:
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))

In [4]:
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader, validation_loader):
    prev_validation_loss = None
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device)
            outputs = model(imgs)
            loss = loss_fn(outputs, labels).float()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()
            
        #Функция потерь на валидационных данных
        summary_loss = 0
        for img, label in validation_loader:
            label = label.to(device=device)
            out = model(img.to(device=device, dtype=torch.float32))
            summary_loss += loss_fn(out, label)

        validation_loss = summary_loss/len(validation_loader)
    
        if epoch == 1 or epoch % 5 == 0:
        
            print('{} Epoch {}, Training loss {}, Validation loss {}'.format(
            datetime.datetime.now(),
            epoch,
            loss_train / len(train_loader),
            validation_loss)
            )
            if prev_validation_loss is not None and prev_validation_loss <= validation_loss:
                print(f"Early stop on epoch {epoch}")
                break
            else:
                prev_validation_loss = validation_loss
    
    
        

        


In [5]:
model = CustomModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
training_loop(n_epochs=200, optimizer=optimizer, model=model, loss_fn=loss_fn, 
              train_loader=train_dataloader, validation_loader=validation_dataloader)

  out = F.softmax(self.fc2(out))


2024-03-15 15:24:00.419892 Epoch 1, Training loss 0.6942727770805359, Validation loss 0.6931477189064026
2024-03-15 15:24:16.838136 Epoch 5, Training loss 0.6901377231280009, Validation loss 0.6906041502952576
2024-03-15 15:24:38.079592 Epoch 10, Training loss 0.6850043797492981, Validation loss 0.6867667436599731
2024-03-15 15:24:59.686831 Epoch 15, Training loss 0.6751467156410217, Validation loss 0.6800328493118286
2024-03-15 15:25:22.021589 Epoch 20, Training loss 0.6597509387334188, Validation loss 0.671393632888794
2024-03-15 15:25:43.719264 Epoch 25, Training loss 0.64295742893219, Validation loss 0.6630881428718567
2024-03-15 15:26:05.536950 Epoch 30, Training loss 0.6253913781642914, Validation loss 0.6529484391212463
2024-03-15 15:26:27.753378 Epoch 35, Training loss 0.6037578084468842, Validation loss 0.6351766586303711
2024-03-15 15:26:50.420336 Epoch 40, Training loss 0.5787639768123627, Validation loss 0.6087254285812378
2024-03-15 15:27:12.930086 Epoch 45, Training loss 

In [6]:
torch.save(model.state_dict(), "custom_model.pt")

In [7]:
alexnet = models.alexnet(pretrained=True).to(device=device)
last_layer = alexnet.classifier[-1]
last_layer.out_features = 2
torch.nn.init.xavier_uniform_(last_layer.weight) 

optimizer = torch.optim.SGD(alexnet.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
training_loop(n_epochs=200, optimizer=optimizer, model=alexnet, loss_fn=loss_fn, 
              train_loader=train_dataloader, validation_loader=validation_dataloader)
torch.save(alexnet.state_dict(), 'alexnet.pt')



2024-03-15 15:28:57.814443 Epoch 1, Training loss 0.48655205079416436, Validation loss 0.29318419098854065
2024-03-15 15:29:49.279523 Epoch 5, Training loss 0.17911907649288575, Validation loss 0.18835394084453583
2024-03-15 15:30:54.277482 Epoch 10, Training loss 0.10999122543198367, Validation loss 0.22242684662342072
Early stop on epoch 10


In [8]:
resnet18 = models.resnet18(pretrained=True)
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, 2)
torch.nn.init.xavier_uniform_(resnet18.fc.weight) 
resnet18 = resnet18.to(device=device)
optimizer = torch.optim.SGD(resnet18.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
training_loop(n_epochs=200, optimizer=optimizer, model=resnet18, loss_fn=loss_fn, 
              train_loader=train_dataloader, validation_loader=validation_dataloader)
torch.save(resnet18.state_dict(), 'resnet.pt')



2024-03-15 15:31:19.505210 Epoch 1, Training loss 0.3063100711219013, Validation loss 0.28808778524398804
2024-03-15 15:37:20.436890 Epoch 5, Training loss 0.08730832822388038, Validation loss 0.20573806762695312
2024-03-15 15:45:55.415923 Epoch 10, Training loss 0.07073091881012078, Validation loss 0.23982344567775726
Early stop on epoch 10
