# Подготовка, импорты библиотек

In [1]:
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt


import torchvision
import torch.utils.data as data
import torchvision.models as models
import torchvision.transforms as transforms

import PIL
from PIL import Image
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings('ignore')

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [3]:
# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
# The CUDA probe only runs when MPS is unavailable, as in an if/elif chain.
DEVICE = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Сбор датасетов для обучения, валидации и теста

In [18]:
# Root folder of the dataset and the path to the training annotation CSV
DATA_PATH = "./data/"
TRAIN_ANN_PATH = f"{DATA_PATH}train.csv"

# Load the annotation table and print its first two rows as a sanity check
train_df = pd.read_csv(TRAIN_ANN_PATH)
print(train_df.head(2))

     filename  class_number
0  000000.png            18
1  000001.png            18


In [19]:
class RoadSignDataset(data.Dataset):
    """Labelled road-sign image dataset backed by an annotation DataFrame.

    Arguments:
        root (str): directory that contains the image files
        flist (pandas.DataFrame): annotation table; every row is unpacked as
            (image file name, class label), e.g. the columns ``filename``
            and ``class_number``
        transform (callable, optional): applied to every loaded PIL image;
            when ``None`` the RGB PIL image is returned unchanged
    """

    def __init__(self, root, flist, transform=None):
        self.root = root
        self.imlist = flist
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(img, target, impath)`` for the row at position ``index``.

        Raises:
            IndexError: if ``index`` is past the end of the table
            FileNotFoundError: if the image file does not exist under ``root``
        """
        # Positional lookup: independent of the DataFrame's index labels, and
        # it raises IndexError (not KeyError) when running past the end.
        impath, target = self.imlist.iloc[index]

        # Build the full path to the image file
        full_imname = os.path.join(self.root, impath)

        # Fail early with a clear message instead of printing and then
        # crashing inside Image.open
        if not os.path.exists(full_imname):
            raise FileNotFoundError(f'No file {full_imname}')

        # convert() returns a new, fully loaded image, so the file handle can
        # be released right away
        with Image.open(full_imname) as raw:
            img = raw.convert('RGB')

        # Apply the selected transform (augmentation), if one was given
        if self.transform is not None:
            img = self.transform(img)

        # img and target are used for training/validation; impath identifies
        # the sample
        return img, target, impath

    def __len__(self):
        """Return the number of samples, i.e. the number of annotation rows."""
        return len(self.imlist)

In [20]:
class RoadSignTestDataset(data.Dataset):
    """Unlabelled road-sign image dataset (images without class labels).

    Arguments:
        root (str): directory that contains the image files
        flist (list of str, optional): image file names relative to ``root``;
            when ``None`` the directory is scanned for jpg/jpeg/png files
            (case-insensitive) and the names are stored in sorted order
        transform (callable, optional): applied to every loaded PIL image;
            when ``None`` the RGB PIL image is returned unchanged
    """

    # Extensions (lower-case, with the leading dot) accepted when scanning root
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, flist=None, transform=None):
        self.root = root

        if flist is not None:
            self.imlist = flist
        else:
            # os.listdir returns names in arbitrary order, so sort them to keep
            # the sample order reproducible. splitext + lower() accepts
            # upper-case extensions and rejects dot-less names such as "png".
            self.imlist = sorted(
                filename
                for filename in os.listdir(self.root)
                if os.path.splitext(filename)[1].lower() in self.IMAGE_EXTENSIONS
            )

        self.transform = transform

    def __getitem__(self, index):
        """Return ``(img, impath)`` for the file at position ``index``.

        Raises:
            IndexError: if ``index`` is out of range for the file list
            FileNotFoundError: if the image file does not exist under ``root``
        """
        impath = self.imlist[index]

        # Build the full path to the image file
        full_imname = os.path.join(self.root, impath)

        # Fail early with a clear message instead of printing and then
        # crashing inside Image.open
        if not os.path.exists(full_imname):
            raise FileNotFoundError(f'No file {full_imname}')

        # PIL is used because the transforms are applied to a PIL image;
        # convert() returns a new, fully loaded image, so the file handle can
        # be released right away
        with Image.open(full_imname) as raw:
            img = raw.convert('RGB')

        # Apply the selected transform, if one was given
        if self.transform is not None:
            img = self.transform(img)

        # There is no label here: only the image and its file name are returned
        return img, impath

    def __len__(self):
        """Return the number of samples, i.e. the length of the file list."""
        return len(self.imlist)

In [21]:
# Preprocessing for train and val: convert to a tensor, then normalize every
# channel with mean 0.5 and std 0.5
transform_for_train_and_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Preprocessing for test; for a start it is the same pipeline
transform_for_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Split the annotation table into train and val parts. The datasets look rows
# up by index, so each part gets a fresh 0..n-1 index.
train, val = (
    part.reset_index(drop=True)
    for part in train_test_split(train_df, test_size=0.1, random_state=42)
)

batch_size = 64

In [22]:
# Datasets: train and val share one image folder and differ only in the
# annotation rows they receive; the test set is built by scanning its folder.
trainset = RoadSignDataset(root='./data/train', flist=train, transform=transform_for_train_and_val)
valset = RoadSignDataset(root='./data/train', flist=val, transform=transform_for_train_and_val)
# Use the dedicated test transform so that later changes to the train/val
# pipeline (e.g. augmentation) do not silently affect test images.
testset = RoadSignTestDataset(root='./data/test', flist=None, transform=transform_for_test)

# Loaders: only the training data is shuffled; val and test keep dataset order.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True)
valloader = DataLoader(valset, batch_size=batch_size, pin_memory=True)
testloader = DataLoader(testset, batch_size=batch_size, pin_memory=True)