# Swin Transformer

In [4]:
# uvolnenie GPU pamate
import torch

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [1]:
!nvidia-smi

Sun Mar 30 11:04:23 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.230.02             Driver Version: 535.230.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Quadro P4000                   Off | 00000000:8B:00.0 Off |                  N/A |
| 39%   29C    P0              25W / 105W |      0MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [1]:
!pip install torch torchvision transformers timm



In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import timm
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

# ✅ Dataset class
class MeteorDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_names = [fname for fname in os.listdir(image_dir) if fname.endswith(('.jpg', '.png'))]
        self.image_paths = [os.path.join(image_dir, fname) for fname in self.image_names]
        self.label_paths = [os.path.join(label_dir, fname.replace('.jpg', '.txt').replace('.png', '.txt')) for fname in self.image_names]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]
        img = Image.open(img_path).convert("RGB")

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                content = f.read().strip()
                label = 1 if '0' in content else 0
        else:
            label = 0

        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

# ✅ Augmentations
image_size = 384  # reduce size to fit memory
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

valid_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ✅ Datasets
data_path = "/home/jovyan/data/lightning/LiviaMurankova/new/namnozene_meteory/DP_rozdelenie_dat/data_split_v2"
train_dataset = MeteorDataset(f"{data_path}/train/images", f"{data_path}/train/labels", transform=train_transforms)
valid_dataset = MeteorDataset(f"{data_path}/valid/images", f"{data_path}/valid/labels", transform=valid_transforms)

# ✅ Weighted sampler
label_list = [int(train_dataset.__getitem__(idx)[1].item()) for idx in range(len(train_dataset))]
num_non_meteor, num_meteor = label_list.count(0), label_list.count(1)
print(f"📊 Meteors: {num_meteor}, Non-Meteors: {num_non_meteor}")
weights = [1.0 / num_non_meteor if label == 0 else 1.0 / num_meteor for label in label_list]
sampler = WeightedRandomSampler(weights, num_samples=len(label_list), replacement=True)

batch_size = 2  # reduce batch size to save memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# ✅ Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        targets = targets.unsqueeze(1)
        BCE_loss = self.bce(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

# ✅ Model setup: SwinV2 Small s globálnym poolingom pre klasifikáciu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model(
    "swinv2_small_window8_256",
    pretrained=True,
    num_classes=1,
    strict_img_size=False,
    global_pool='avg'
)
model = model.to(device)

# ✅ Training setup
criterion = FocalLoss(alpha=0.75, gamma=2)
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=3e-4)

epochs = 30
num_training_steps = len(train_loader) * epochs
num_warmup_steps = len(train_loader) * 5
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

# ✅ Training loop
best_val_loss = float('inf')
patience, trigger_times = 5, 0

torch.cuda.empty_cache()  # free memory before training

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images, labels = images.to(device), labels.to(device).float()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # ✅ Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device).float()
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()

    avg_val_loss = val_loss / len(valid_loader)
    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")

    # ✅ Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "datasetv2_best_swinv2_model.pth")
        print("✅ Model saved (best so far)")
        trigger_times = 0
    else:
        trigger_times += 1
        print(f"⚠️ Early stopping counter: {trigger_times}/{patience}")
        if trigger_times >= patience:
            print("⛔ Early stopping triggered")
            break

print("🎉 Training completed!")

📊 Meteors: 5191, Non-Meteors: 2735


model.safetensors:   0%|          | 0.00/201M [00:00<?, ?B/s]

Epoch 1/30: 100%|██████████| 3963/3963 [17:07<00:00,  3.86it/s]


Epoch 1: Train Loss=0.0916, Val Loss=0.0460
✅ Model saved (best so far)


Epoch 2/30: 100%|██████████| 3963/3963 [17:26<00:00,  3.79it/s]


Epoch 2: Train Loss=0.0406, Val Loss=0.0262
✅ Model saved (best so far)


Epoch 3/30: 100%|██████████| 3963/3963 [17:36<00:00,  3.75it/s]


Epoch 3: Train Loss=0.0300, Val Loss=0.0279
⚠️ Early stopping counter: 1/5


Epoch 4/30: 100%|██████████| 3963/3963 [17:48<00:00,  3.71it/s]


Epoch 4: Train Loss=0.0252, Val Loss=0.0248
✅ Model saved (best so far)


Epoch 5/30: 100%|██████████| 3963/3963 [18:23<00:00,  3.59it/s]


Epoch 5: Train Loss=0.0242, Val Loss=0.0220
✅ Model saved (best so far)


Epoch 6/30: 100%|██████████| 3963/3963 [17:28<00:00,  3.78it/s]


Epoch 6: Train Loss=0.0209, Val Loss=0.0163
✅ Model saved (best so far)


Epoch 7/30: 100%|██████████| 3963/3963 [17:14<00:00,  3.83it/s]


Epoch 7: Train Loss=0.0193, Val Loss=0.0181
⚠️ Early stopping counter: 1/5


Epoch 8/30: 100%|██████████| 3963/3963 [17:12<00:00,  3.84it/s]


Epoch 8: Train Loss=0.0172, Val Loss=0.0180
⚠️ Early stopping counter: 2/5


Epoch 9/30: 100%|██████████| 3963/3963 [17:51<00:00,  3.70it/s]


Epoch 9: Train Loss=0.0169, Val Loss=0.0137
✅ Model saved (best so far)


Epoch 10/30: 100%|██████████| 3963/3963 [18:22<00:00,  3.59it/s]


Epoch 10: Train Loss=0.0155, Val Loss=0.0139
⚠️ Early stopping counter: 1/5


Epoch 11/30: 100%|██████████| 3963/3963 [17:31<00:00,  3.77it/s]


Epoch 11: Train Loss=0.0142, Val Loss=0.0127
✅ Model saved (best so far)


Epoch 12/30:   5%|▌         | 215/3963 [00:56<16:28,  3.79it/s]

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import timm
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

# ✅ Dataset class
class MeteorDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_names = [fname for fname in os.listdir(image_dir) if fname.endswith(('.jpg', '.png'))]
        self.image_paths = [os.path.join(image_dir, fname) for fname in self.image_names]
        self.label_paths = [os.path.join(label_dir, fname.replace('.jpg', '.txt').replace('.png', '.txt')) for fname in self.image_names]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]
        img = Image.open(img_path).convert("RGB")

        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                content = f.read().strip()
                label = 1 if '0' in content else 0
        else:
            label = 0

        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

# ✅ Augmentations
image_size = 512  # ⬆️ zväčšený vstupný rozmer
train_transforms = transforms.Compose([
    transforms.Resize((image_size + 32, image_size + 32)),
    transforms.RandomCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

valid_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ✅ Datasets
data_path = "/home/jovyan/data/lightning/LiviaMurankova/new/namnozene_meteory/DP_rozdelenie_dat/data_split_v2"
train_dataset = MeteorDataset(f"{data_path}/train/images", f"{data_path}/train/labels", transform=train_transforms)
valid_dataset = MeteorDataset(f"{data_path}/valid/images", f"{data_path}/valid/labels", transform=valid_transforms)

# ✅ Weighted sampler
label_list = [int(train_dataset.__getitem__(idx)[1].item()) for idx in range(len(train_dataset))]
num_non_meteor, num_meteor = label_list.count(0), label_list.count(1)
print(f"📊 Meteors: {num_meteor}, Non-Meteors: {num_non_meteor}")
weights = [1.0 / num_non_meteor if label == 0 else 1.0 / num_meteor for label in label_list]
sampler = WeightedRandomSampler(weights, num_samples=len(label_list), replacement=True)

batch_size = 2  # menší batch kvôli pamäti
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# ✅ Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        targets = targets.unsqueeze(1)
        BCE_loss = self.bce(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

# ✅ Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model(
    "swinv2_small_window8_256",
    pretrained=True,
    num_classes=1,
    strict_img_size=False,
    global_pool='avg'
)
model = model.to(device)

# ✅ Training setup
criterion = FocalLoss(alpha=0.75, gamma=2)
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-4)  # upravený LR

epochs = 35
num_training_steps = len(train_loader) * epochs
num_warmup_steps = len(train_loader) * 3
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

# ✅ Training loop
best_val_loss = float('inf')
patience, trigger_times = 5, 0

torch.cuda.empty_cache()

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images, labels = images.to(device), labels.to(device).float()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device).float()
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()

    avg_val_loss = val_loss / len(valid_loader)
    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "datasetv2_best_swinv2_model_v2.pth")
        print("✅ Model saved (best so far)")
        trigger_times = 0
    else:
        trigger_times += 1
        print(f"⚠️ Early stopping counter: {trigger_times}/{patience}")
        if trigger_times >= patience:
            print("⛔ Early stopping triggered")
            break

print("🎉 Training completed!")

📊 Meteors: 5191, Non-Meteors: 2735


model.safetensors:   0%|          | 0.00/201M [00:00<?, ?B/s]

Epoch 1/35: 100%|██████████| 3963/3963 [25:46<00:00,  2.56it/s]


Epoch 1: Train Loss=0.0609, Val Loss=0.0244
✅ Model saved (best so far)


Epoch 2/35: 100%|██████████| 3963/3963 [26:05<00:00,  2.53it/s]


Epoch 2: Train Loss=0.0291, Val Loss=0.0236
✅ Model saved (best so far)


Epoch 3/35: 100%|██████████| 3963/3963 [25:58<00:00,  2.54it/s]


Epoch 3: Train Loss=0.0269, Val Loss=0.0254
⚠️ Early stopping counter: 1/5


Epoch 4/35: 100%|██████████| 3963/3963 [25:48<00:00,  2.56it/s]


Epoch 4: Train Loss=0.0260, Val Loss=0.0306
⚠️ Early stopping counter: 2/5


Epoch 5/35: 100%|██████████| 3963/3963 [25:45<00:00,  2.56it/s]


Epoch 5: Train Loss=0.0225, Val Loss=0.0181
✅ Model saved (best so far)


Epoch 6/35: 100%|██████████| 3963/3963 [25:43<00:00,  2.57it/s]


Epoch 6: Train Loss=0.0210, Val Loss=0.0150
✅ Model saved (best so far)


Epoch 7/35: 100%|██████████| 3963/3963 [25:41<00:00,  2.57it/s]


Epoch 7: Train Loss=0.0192, Val Loss=0.0271
⚠️ Early stopping counter: 1/5


Epoch 8/35: 100%|██████████| 3963/3963 [25:39<00:00,  2.57it/s]


Epoch 8: Train Loss=0.0161, Val Loss=0.0215
⚠️ Early stopping counter: 2/5


Epoch 9/35: 100%|██████████| 3963/3963 [25:39<00:00,  2.57it/s]


Epoch 9: Train Loss=0.0161, Val Loss=0.0209
⚠️ Early stopping counter: 3/5


Epoch 10/35: 100%|██████████| 3963/3963 [25:37<00:00,  2.58it/s]


Epoch 10: Train Loss=0.0141, Val Loss=0.0155
⚠️ Early stopping counter: 4/5


Epoch 11/35: 100%|██████████| 3963/3963 [25:37<00:00,  2.58it/s]


Epoch 11: Train Loss=0.0157, Val Loss=0.0113
✅ Model saved (best so far)


Epoch 12/35: 100%|██████████| 3963/3963 [25:37<00:00,  2.58it/s]


Epoch 12: Train Loss=0.0132, Val Loss=0.0107
✅ Model saved (best so far)


Epoch 13/35: 100%|██████████| 3963/3963 [25:39<00:00,  2.57it/s]


Epoch 13: Train Loss=0.0118, Val Loss=0.0201
⚠️ Early stopping counter: 1/5


Epoch 14/35: 100%|██████████| 3963/3963 [25:40<00:00,  2.57it/s]


Epoch 14: Train Loss=0.0122, Val Loss=0.0130
⚠️ Early stopping counter: 2/5


Epoch 15/35: 100%|██████████| 3963/3963 [25:40<00:00,  2.57it/s]


Epoch 15: Train Loss=0.0106, Val Loss=0.0105
✅ Model saved (best so far)


Epoch 16/35: 100%|██████████| 3963/3963 [25:41<00:00,  2.57it/s]


Epoch 16: Train Loss=0.0083, Val Loss=0.0172
⚠️ Early stopping counter: 1/5


Epoch 17/35: 100%|██████████| 3963/3963 [25:39<00:00,  2.57it/s]


Epoch 17: Train Loss=0.0088, Val Loss=0.0117
⚠️ Early stopping counter: 2/5


Epoch 18/35: 100%|██████████| 3963/3963 [25:40<00:00,  2.57it/s]


Epoch 18: Train Loss=0.0090, Val Loss=0.0078
✅ Model saved (best so far)


Epoch 19/35:   1%|          | 40/3963 [00:15<25:22,  2.58it/s]