1. Импорт библиотек и настройка устройства

In [None]:
import torch
import torchvision
from torchvision import transforms, models
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim

In [None]:
# Pick the compute device: use the GPU when CUDA is available, otherwise fall back to CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

2. Подготовка данных и аугментация

Добавьте трансформации и разделите данные на тренировочный и валидационный наборы:

In [None]:
# Training-time preprocessing: random resized 224x224 crop and random horizontal
# flip for augmentation, then tensor conversion and per-channel normalization
# with the ImageNet statistics expected by the pretrained backbone.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

In [None]:
# Deterministic validation preprocessing (no augmentation): resize, take the
# central 224x224 crop, convert to a tensor and normalize with ImageNet statistics.
val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
val_transform = transforms.Compose(val_steps)

In [None]:
# Load every image under ../images; ImageFolder assigns class labels from the
# sub-directory names. Images are returned with the training transform applied.
images_root = '../images'
dataset = torchvision.datasets.ImageFolder(root=images_root, transform=train_transform)

In [None]:
# Split into training and validation subsets (80/20).
# A fixed-seed generator makes the split identical across runs / kernel restarts,
# so the saved "best" checkpoint is always judged on the same held-out images.
split_generator = torch.Generator().manual_seed(42)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_split = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# random_split only wraps indices: both subsets would read through `dataset`
# and therefore apply its random train_transform (crop/flip) to validation
# images too, making validation metrics noisy. Re-index the same validation
# images over a second ImageFolder (same root, same sorted file order) that
# uses the deterministic val_transform instead.
val_source = torchvision.datasets.ImageFolder(root=dataset.root, transform=val_transform)
val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

In [None]:
# Wrap the subsets in DataLoaders; only the training loader reshuffles each epoch.
batch_size = 32
loader_kwargs = dict(batch_size=batch_size)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

3. Загрузка предобученной модели и модификация

Используем предобученную ResNet18: заморозим её веса и заменим последний полносвязный слой (`fc`) на новый классификатор:

In [None]:
# Load ResNet18 with ImageNet-pretrained weights.
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Freeze the pretrained backbone: no gradients are computed for these weights.
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head. The new nn.Linear is created after the
# freeze loop, so its parameters are trainable (requires_grad defaults to True).
# The output size is taken from the dataset instead of being hard-coded to 2,
# so the head always matches the labels ImageFolder produced (here presumably
# 'defective' and 'good'; ImageFolder orders classes alphabetically).
num_classes = len(dataset.classes)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
model = model.to(device)

4. Определение функции потерь и оптимизатора

In [None]:
# Cross-entropy over the class logits. Only the new head's parameters are given
# to Adam, since the backbone weights are frozen.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.fc.parameters(), lr=1e-3)

5. Обучение модели

In [None]:
# Train the classification head for num_epochs, evaluating on the validation
# loader after every epoch and saving the checkpoint with the best accuracy.
num_epochs = 15
best_accuracy = 0.0

for epoch in range(num_epochs):
    # NOTE(review): model.train() also switches the frozen backbone's BatchNorm
    # layers to training mode, so their running statistics are still updated
    # even though requires_grad=False. If the backbone is meant to be fully
    # frozen, consider putting those layers back into eval() here.
    model.train()
    running_loss = 0.0
    train_samples = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean (default reduction); weight by the
        # batch size so the epoch average is per-sample, even when the last
        # batch is smaller.
        running_loss += loss.item() * labels.size(0)
        train_samples += labels.size(0)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against empty loaders (e.g. a tiny dataset producing an empty split)
    # instead of raising ZeroDivisionError.
    avg_train_loss = running_loss / train_samples if train_samples else float("nan")
    avg_val_loss = val_loss / total if total else float("nan")
    val_accuracy = 100 * correct / total if total else 0.0

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f}")
    print(f"Val Accuracy: {val_accuracy:.2f}%")

    # Save the best model so far, together with the label mapping needed to
    # interpret its outputs at inference time.
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save({
            'model_state_dict': model.state_dict(),
            'class_to_idx': dataset.class_to_idx
        }, 'best_tire_classifier.pt')
        print("New best model saved!")

6. Сохранение модели

In [None]:
# The best checkpoint (weights + class_to_idx) is already written to
# 'best_tire_classifier.pt' by the training loop. Uncomment the lines below
# only to additionally save the final-epoch weights as a plain state_dict.
# torch.save(model.state_dict(), 'tire_classifier.pt')
# print("Model saved!")

7. Загрузка модели для использования

In [None]:
# Example: reload the best checkpoint saved by the training loop. That file is
# a dict with 'model_state_dict' and 'class_to_idx', not a bare state_dict.
# (`pretrained=False` is deprecated in recent torchvision; use `weights=None`.)
# checkpoint = torch.load('best_tire_classifier.pt', map_location=device)
# model = models.resnet18(weights=None)
# model.fc = nn.Linear(model.fc.in_features, len(checkpoint['class_to_idx']))
# model.load_state_dict(checkpoint['model_state_dict'])
# model = model.to(device)
# model.eval()