# DeiT

In [6]:
!nvidia-smi

Thu Mar 27 10:58:37 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Quadro P4000                   Off | 00000000:8B:00.0 Off |                  N/A |
| 48%   35C    P8               4W / 105W |   8096MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
!pip install torch torchvision transformers timm



In [5]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.cuda.empty_cache()
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup
import timm

# ✅ Dataset class
class MeteorDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_names = [fname for fname in os.listdir(image_dir) if fname.endswith(('.jpg', '.png'))]
        self.image_paths = [os.path.join(image_dir, fname) for fname in self.image_names]
        self.label_paths = [os.path.join(label_dir, fname.replace('.jpg', '.txt').replace('.png', '.txt')) for fname in self.image_names]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]
        img = Image.open(img_path).convert("RGB")
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                content = f.read().strip()
                label = 1 if '0' in content else 0
        else:
            label = 0
        
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

# ✅ Transforms for ViT (smaller resolution due to memory)
image_size = 384  # ViT default input size

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(384, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

valid_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# ✅ Paths & datasets
data_path = "/home/jovyan/data/lightning/LiviaMurankova/new/namnozene_meteory/DP_rozdelenie_dat/data_split_v2"
train_dataset = MeteorDataset(f"{data_path}/train/images", f"{data_path}/train/labels", transform=train_transforms)
valid_dataset = MeteorDataset(f"{data_path}/valid/images", f"{data_path}/valid/labels", transform=valid_transforms)

# ✅ Weighted sampler
label_list = [int(train_dataset.__getitem__(i)[1].item()) for i in range(len(train_dataset))]
num_non_meteor, num_meteor = label_list.count(0), label_list.count(1)
weights = [1.0 / num_non_meteor if label == 0 else 1.0 / num_meteor for label in label_list]
sampler = WeightedRandomSampler(weights, num_samples=len(label_list), replacement=True)

batch_size = 4  # smaller image = can use larger batch
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# ✅ Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        targets = targets.unsqueeze(1)
        BCE_loss = self.bce(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

# ✅ Model setup: Vision Transformer (ViT)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model("deit_base_patch16_384", pretrained=True, num_classes=1)

# Replace head for binary classification
model.head = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.head.in_features, 1)
)

model = model.to(device) 

# ✅ Training setup
criterion = FocalLoss(alpha=0.75, gamma=2)
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=2e-4)

epochs = 25
num_training_steps = len(train_loader) * epochs
num_warmup_steps = len(train_loader) * 3  # 3 warmup epochs

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

# ✅ Training loop
best_val_loss = float('inf')
patience, trigger_times = 5, 0

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images = images.to(device)
        labels = labels.to(device).float()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # ✅ Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device).float()
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()

    avg_val_loss = val_loss / len(valid_loader)
    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")

    # ✅ Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "datasetv2_best_deit_model_v2.pth")
        print("✅ Model saved (best so far)")
        trigger_times = 0
    else:
        trigger_times += 1
        print(f"⚠️ Early stopping counter: {trigger_times}/{patience}")
        if trigger_times >= patience:
            print("⛔ Early stopping triggered")
            break

print("🎉 Training completed!")

model.safetensors:   0%|          | 0.00/347M [00:00<?, ?B/s]

Epoch 1/25: 100%|██████████| 1982/1982 [22:44<00:00,  1.45it/s]


Epoch 1: Train Loss=0.0870, Val Loss=0.0699
✅ Model saved (best so far)


Epoch 2/25: 100%|██████████| 1982/1982 [27:36<00:00,  1.20it/s]


Epoch 2: Train Loss=0.0375, Val Loss=0.0254
✅ Model saved (best so far)


Epoch 3/25: 100%|██████████| 1982/1982 [26:29<00:00,  1.25it/s]


Epoch 3: Train Loss=0.0293, Val Loss=0.0252
✅ Model saved (best so far)


Epoch 4/25: 100%|██████████| 1982/1982 [24:49<00:00,  1.33it/s]


Epoch 4: Train Loss=0.0240, Val Loss=0.0238
✅ Model saved (best so far)


Epoch 5/25: 100%|██████████| 1982/1982 [24:19<00:00,  1.36it/s]


Epoch 5: Train Loss=0.0232, Val Loss=0.0172
✅ Model saved (best so far)


Epoch 6/25: 100%|██████████| 1982/1982 [25:16<00:00,  1.31it/s]


Epoch 6: Train Loss=0.0193, Val Loss=0.0161
✅ Model saved (best so far)


Epoch 7/25: 100%|██████████| 1982/1982 [25:11<00:00,  1.31it/s]


Epoch 7: Train Loss=0.0184, Val Loss=0.0136
✅ Model saved (best so far)


Epoch 8/25: 100%|██████████| 1982/1982 [23:34<00:00,  1.40it/s]


Epoch 8: Train Loss=0.0153, Val Loss=0.0171
⚠️ Early stopping counter: 1/5


Epoch 9/25: 100%|██████████| 1982/1982 [23:59<00:00,  1.38it/s]


Epoch 9: Train Loss=0.0168, Val Loss=0.0144
⚠️ Early stopping counter: 2/5


Epoch 10/25: 100%|██████████| 1982/1982 [22:42<00:00,  1.45it/s]


Epoch 10: Train Loss=0.0145, Val Loss=0.0161
⚠️ Early stopping counter: 3/5


Epoch 11/25: 100%|██████████| 1982/1982 [22:45<00:00,  1.45it/s]


Epoch 11: Train Loss=0.0151, Val Loss=0.0144
⚠️ Early stopping counter: 4/5


Epoch 12/25: 100%|██████████| 1982/1982 [22:47<00:00,  1.45it/s]


Epoch 12: Train Loss=0.0113, Val Loss=0.0173
⚠️ Early stopping counter: 5/5
⛔ Early stopping triggered
🎉 Training completed!
