# ConvNeXt Small

In [3]:
!pip install torch torchvision transformers timm



In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import timm
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

# Dataset class
class MeteorDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_names = [fname for fname in os.listdir(image_dir) if fname.endswith(('.jpg', '.png'))]
        self.image_paths = [os.path.join(image_dir, fname) for fname in self.image_names]
        self.label_paths = [os.path.join(label_dir, fname.replace('.jpg', '.txt').replace('.png', '.txt')) for fname in self.image_names]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]
        img = Image.open(img_path).convert("RGB")
        
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                content = f.read().strip()
                label = 1 if '0' in content else 0
        else:
            label = 0
        
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32)

# Augmentations
train_transforms = transforms.Compose([
    #transforms.RandomResizedCrop(720, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    #transforms.RandomErasing(p=0.2, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

valid_transforms = transforms.Compose([
    transforms.Resize((720, 720)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Datasets
data_path = "/home/jovyan/data/lightning/LiviaMurankova/new/namnozene_meteory/DP_rozdelenie_dat/data_split_v2"
train_dataset = MeteorDataset(f"{data_path}/train/images", f"{data_path}/train/labels", transform=train_transforms)
valid_dataset = MeteorDataset(f"{data_path}/valid/images", f"{data_path}/valid/labels", transform=valid_transforms)

# Weighted sampler
label_list = [int(train_dataset.__getitem__(idx)[1].item()) for idx in range(len(train_dataset))]
num_non_meteor, num_meteor = label_list.count(0), label_list.count(1)
print(f"📊 Meteors: {num_meteor}, Non-Meteors: {num_non_meteor}")
weights = [1.0 / num_non_meteor if label == 0 else 1.0 / num_meteor for label in label_list]
sampler = WeightedRandomSampler(weights, num_samples=len(label_list), replacement=True)

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        targets = targets.unsqueeze(1)
        BCE_loss = self.bce(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

# Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model("convnext_small", pretrained=True, num_classes=1).to(device)
model.classifier = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.num_features, 1)
)

# Training setup
criterion = FocalLoss(alpha=0.75, gamma=2)
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=2e-4)

epochs = 30
num_training_steps = len(train_loader) * epochs
num_warmup_steps = len(train_loader) * 5  # 5 warmup epoch steps
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

# Training loop
best_val_loss = float('inf')
patience, trigger_times = 5, 0

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images, labels = images.to(device), labels.to(device).float()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device).float()
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()

    avg_val_loss = val_loss / len(valid_loader)

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")

    # Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "datasetv2_best_convnext_model_v65.pth")
        print("✅ Model saved (best so far)")
        trigger_times = 0
    else:
        trigger_times += 1
        print(f"⚠️ Early stopping counter: {trigger_times}/{patience}")
        if trigger_times >= patience:
            print("⛔ Early stopping triggered")
            break

print("🎉 Training completed!")


📊 Meteors: 5191, Non-Meteors: 2735


Epoch 1/30: 100%|██████████| 991/991 [1:49:45<00:00,  6.65s/it]


Epoch 1: Train Loss=0.0889, Val Loss=0.0588
✅ Model saved (best so far)


Epoch 2/30: 100%|██████████| 991/991 [2:36:02<00:00,  9.45s/it]  


Epoch 2: Train Loss=0.0265, Val Loss=0.0126
✅ Model saved (best so far)


Epoch 3/30:  13%|█▎        | 132/991 [20:01<2:47:40, 11.71s/it]