In [None]:
import random

import numpy as np
import torch

model_checkpoint = 'timm/maxvit_large_tf_384.in21k_ft_in1k' # pre-trained model from which to fine-tune
model_name = model_checkpoint.split("/")[-1]
save_name = f"{model_name}-finetuned"
batch_size = 32 # batch size for training and evaluation
num_epoch = 128  # upper bound on epochs; the early-stopping check may end training sooner
seed = 42  # fixed seed so runs are reproducible

# Seed every RNG the notebook touches, e.g. new classifier-head init and DataLoader
# shuffling (both use torch's default generator). Full bitwise determinism on CUDA
# additionally needs deterministic cuDNN settings, which are not enabled here.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
from datasets import load_dataset 

# Load train/val splits with the "imagefolder" builder, which (per the HF docs) infers class
# labels from sub-directory names -- presumably Dataset/<split>/<class_name>/<image>; verify layout.
dataset = load_dataset(
    "imagefolder",
    data_files=dict(train="Dataset/train/**", val="Dataset/val/**"),
)

In [None]:
import timm

# Pretrained backbone; num_classes sizes the classifier to the label names of the training
# split, so timm builds a fresh head instead of the checkpoint's original classifier.
model = timm.create_model(
    model_checkpoint,
    pretrained=True,
    num_classes=len(dataset['train'].features['label'].names),
)

In [None]:
from torchvision import transforms
import torch

from timm.data import resolve_model_data_config

# Preprocess the way the checkpoint was pre-trained. timm records the expected input size,
# resize interpolation and normalisation mean/std in the model's pretrained config, so read
# them from there instead of hard-coding ImageNet statistics and the default (bilinear) resize.
data_config = resolve_model_data_config(model)
input_height, input_width = data_config['input_size'][-2:]
normalize = transforms.Normalize(mean=data_config['mean'], std=data_config['std'])
transform = transforms.Compose([
    # Resize directly to the target size (aspect ratio is not preserved)
    transforms.Resize(
        (input_height, input_width),
        interpolation=transforms.InterpolationMode(data_config['interpolation']),
    ),
    transforms.ToTensor(),
    normalize,
])


def _to_model_inputs(examples):
    """Batch transform: convert each PIL image to RGB and then to a normalised tensor.

    Labels pass through unchanged. Only 'pixel_values' and 'label' are returned, so the
    raw 'image' column is not passed on to the DataLoader.
    """
    return {
        'pixel_values': [transform(image.convert("RGB")) for image in examples['image']],
        'label': examples['label'],
    }


# with_transform applies the function lazily on access and replaces any previous transform,
# so re-running this cell is safe.
dataset = dataset.with_transform(_to_model_inputs)


train_loader = torch.utils.data.DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset['val'], batch_size=batch_size)

In [None]:
import torch.nn as nn
import torch.optim as optim

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
# NOTE(review): lr=1e-3 is high for fine-tuning a large pretrained backbone; consider
# 1e-4 or lower (and/or AdamW) if validation loss is unstable.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
import numpy as np


class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    A loss counts as an improvement only if it is more than ``min_delta`` below the best
    loss seen so far; an improvement records the new best and resets the counter.
    Calling the instance returns True whenever the number of consecutive
    non-improving calls has reached ``patience``, and False otherwise.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = np.inf  # no loss observed yet
        self.counter = 0         # consecutive calls without a sufficient improvement

    def __call__(self, val_loss):
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss, self.counter = val_loss, 0
        else:
            self.counter += 1

        should_stop = self.counter >= self.patience
        if should_stop:
            print("Early stopping triggered")
        return should_stop
    
early_stopping = EarlyStopping(patience=3, min_delta=0.01)

In [None]:
from tqdm import tqdm


def _train_one_epoch(epoch):
    """Run one optimisation pass over ``train_loader``; return the mean per-batch training loss."""
    model.train()
    running_loss = 0.0
    total_batches = len(train_loader)  # total number of batches

    # Show training progress with tqdm
    with tqdm(total=total_batches, desc=f'Epoch {epoch + 1}/{num_epoch}', unit='batch') as pbar:
        for i, inputs in enumerate(train_loader):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            optimizer.zero_grad()
            outputs = model(inputs['pixel_values'])
            loss = criterion(outputs, inputs['label'])
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Show the running mean loss on the progress bar
            pbar.set_postfix({'loss': running_loss / (i + 1)})
            pbar.update(1)

    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches
    return running_loss / max(total_batches, 1)


def _evaluate(loader):
    """Score ``model`` on ``loader`` without gradient tracking.

    Returns ``(mean per-batch loss, accuracy in percent, predictions, labels)``; the last
    two are lists of per-sample class indices (NumPy scalars) in loader order.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    preds, labels = [], []
    with torch.no_grad():
        for inputs in loader:
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(inputs['pixel_values'])
            total_loss += criterion(outputs, inputs['label']).item()
            predicted = outputs.argmax(dim=1)
            total += inputs['label'].size(0)
            correct += (predicted == inputs['label']).sum().item()
            preds.extend(predicted.cpu().numpy())
            labels.extend(inputs['label'].cpu().numpy())

    # Guard both divisions against an empty loader
    mean_loss = total_loss / max(len(loader), 1)
    accuracy = 100 * correct / total if total else 0.0
    return mean_loss, accuracy, preds, labels


best_val_loss = float('inf')
best_state = None  # CPU copy of the weights from the lowest-validation-loss epoch

for epoch in range(num_epoch):
    train_loss = _train_one_epoch(epoch)
    print(f'Epoch {epoch + 1}, Training Loss: {train_loss}')

    # Validation evaluation
    val_loss, val_accuracy, epoch_preds, epoch_labels = _evaluate(val_loader)
    print(f'Validation Loss: {val_loss}, Accuracy: {val_accuracy}%')

    # Remember the best epoch's weights (copied to CPU so GPU memory is not doubled) and its
    # predictions. The model saved and plotted afterwards then describes the best epoch, not
    # the last one, which can be several epochs past the best by the time early stopping fires.
    if best_state is None or val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        all_preds, all_labels = epoch_preds, epoch_labels

    # Early-stopping check
    if early_stopping(val_loss):
        break

if best_state is not None:
    # load_state_dict copies the CPU tensors into the model's existing (possibly CUDA) parameters
    model.load_state_dict(best_state)

In [None]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f=save_name)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
class_names = dataset['train'].features['label'].names
# Pin the label set to every class index (ClassLabel ints 0..n-1). The matrix is then always
# len(class_names) x len(class_names), even when a class never appears among the validation
# labels or predictions. Without this, that row/column would be dropped and the remaining ones
# would no longer line up with display_labels.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap=plt.cm.Blues)
disp.ax_.set_title('Validation confusion matrix')
plt.show()

In [None]:
from confusion import Confusion
# Confusion.getValues (project module) is unpacked into two values here; only the first is used.
# Use a named throwaway instead of `_`: assigning `_` at notebook top level makes IPython stop
# updating its `_`/`__`/`___` last-output variables.
performances, _unused_second_value = Confusion.getValues(cm)
performances