In [5]:
from google.colab import drive

# Make Google Drive visible inside the Colab VM; the captured cell output
# confirms the mount at /content/drive.
drive.mount(mountpoint='/content/drive')

# Dataset folder on the mounted drive. Later cells read the 'train' and
# 'valid' sub-folders from here.
DATA_ROOT = '/content/drive/MyDrive/Colab Notebooks/Voederhuiscamera.v2i.multiclass'


Mounted at /content/drive


In [6]:
!pip install timm safetensors tqdm torch torchvision




In [16]:
import os, csv, torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class FeederDataset(Dataset):
    """Image-classification dataset described by a ``_classes.csv`` file.

    ``root`` must contain ``_classes.csv`` and the image files it names. The
    first CSV row is a header (a filename column followed by one column per
    class). Every other row holds a filename followed by one numeric flag per
    class column; the first column whose flag equals ``1`` is the image's class.

    Args:
        root: Directory holding ``_classes.csv`` and the images.
        classes: Ordered class names. When non-empty, the CSV columns are
            matched to these names (ignoring surrounding whitespace) so that
            label indices stay consistent across splits even if a split lists
            its columns in a different order.

    Raises:
        ValueError: If a CSV column is not in ``classes``, or a row marks no
            class with ``1``.
    """

    def __init__(self, root, classes):
        self.root = root
        self.classes = classes
        self.samples = self._read_samples(os.path.join(root, '_classes.csv'), classes)
        # Fixed 224x224 input, normalised with the commonly used ImageNet
        # channel mean/std (the model in this notebook is ImageNet-pretrained).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _read_samples(csv_path, classes=None):
        """Parse ``csv_path`` into a list of ``(filename, label_index)`` tuples.

        With ``classes`` given, ``label_index`` is the position of the row's
        class name inside ``classes``; otherwise it is the position of the
        class column in the CSV header. Blank rows are skipped, and an empty
        file yields an empty list.
        """
        samples = []
        with open(csv_path, newline='') as f:
            reader = csv.reader(f)
            header = next(reader, None)
            if header is None:
                return samples

            columns = [name.strip() for name in header[1:]]
            if classes is not None and len(classes) > 0:
                # Map by name, not by position, so every split shares one label space.
                wanted = {str(name).strip(): i for i, name in enumerate(classes)}
                unknown = [name for name in columns if name not in wanted]
                if unknown:
                    raise ValueError(
                        f"{csv_path}: header columns {unknown!r} are not in classes"
                    )
                col_to_label = [wanted[name] for name in columns]
            else:
                col_to_label = list(range(len(columns)))

            for row in reader:
                if not row:
                    continue  # csv.reader yields [] for blank lines
                filename = row[0]
                flags = [float(value) for value in row[1:]]
                try:
                    column = flags.index(1.0)
                except ValueError:
                    raise ValueError(
                        f"{csv_path}, line {reader.line_num}: "
                        f"no class column is 1 for {filename!r}"
                    ) from None
                if column >= len(col_to_label):
                    raise ValueError(
                        f"{csv_path}, line {reader.line_num}: "
                        f"row for {filename!r} has more values than the header has class columns"
                    )
                samples.append((filename, col_to_label[column]))
        return samples

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_index)`` for sample ``idx``."""
        filename, label = self.samples[idx]
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed right away instead of lingering until garbage collection.
        with Image.open(os.path.join(self.root, filename)) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb), label

# The class list is taken from the header of the training split's label file:
# the first column is the filename, every further column is one class.
with open(os.path.join(DATA_ROOT, 'train', '_classes.csv'), newline='') as f:
    header = next(csv.reader(f))

# Strip padding around the names so that a header written as "filename, a, b"
# does not yield names with a leading space (which would otherwise be carried
# into anything derived from `classes`).
# NOTE(review): assumes the header may be space-padded - confirm against the CSV.
classes = [name.strip() for name in header[1:]]

# Both splits share the same class list.
train_ds = FeederDataset(os.path.join(DATA_ROOT, 'train'), classes)
valid_ds = FeederDataset(os.path.join(DATA_ROOT, 'valid'), classes)

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_ds, batch_size=32, shuffle=False, num_workers=2)


In [21]:
import timm
import torch
from torch import nn
from tqdm import tqdm
import safetensors

# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# EfficientViT-m0 with ImageNet-1K weights; the classification head is then
# swapped for a fresh one with one output per class in `classes`.
model = timm.create_model('efficientvit_m0.r224_in1k', pretrained=True)
model.reset_classifier(len(classes))
model = model.to(device)

# Loss and optimiser used by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Best validation accuracy seen so far and the weights that produced it.
best_acc, best_state = 0.0, None

def evaluate(loader, net=None, dev=None):
    """Return the top-1 accuracy of a model over ``loader``.

    Args:
        loader: Iterable of ``(images, labels)`` tensor batches.
        net: Model to evaluate; defaults to the module-level ``model``.
        dev: Device the batches are moved to; defaults to the module-level
            ``device``.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``,
        or ``0.0`` when ``loader`` yields no samples (instead of raising
        ``ZeroDivisionError``).

    Note:
        The model is switched to eval mode and left that way; callers that
        continue training must call ``.train()`` again.
    """
    net = model if net is None else net
    dev = device if dev is None else dev

    net.eval()
    correct = total = 0
    with torch.no_grad():  # no gradients needed for scoring
        for imgs, labels in loader:
            imgs, labels = imgs.to(dev), labels.to(dev)
            preds = net(imgs).argmax(1)  # index of the highest logit per sample
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total if total else 0.0

NUM_EPOCHS = 20  # number of full passes over the training split

for epoch in range(NUM_EPOCHS):
    # evaluate() leaves the model in eval mode, so training mode has to be
    # re-enabled at the start of every epoch.
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # loss is the batch mean; weight it by batch size so the epoch
        # average is per sample even when the last batch is smaller.
        running_loss += loss.item() * imgs.size(0)

    train_loss = running_loss / len(train_loader.dataset)
    val_acc = evaluate(valid_loader)
    print(f"epoch {epoch+1}: train loss {train_loss:.4f}, valid acc {val_acc*100:.2f}%")

    if val_acc > best_acc:
        best_acc = val_acc
        # Take a real copy of the weights. Tensor.cpu() returns the very same
        # tensor when it already lives on the CPU, so on a CPU-only run the
        # snapshot would alias the live parameters and be overwritten by later
        # optimizer steps; copy=True forces an independent copy on any device.
        best_state = {k: v.detach().to('cpu', copy=True)
                      for k, v in model.state_dict().items()}
        print(f"New best checkpoint at epoch {epoch+1}: {best_acc*100:.2f}%")

# Save the best-performing EfficientViT weights
# Import the submodule explicitly: `import safetensors` on its own is not
# guaranteed to make `safetensors.torch` available as an attribute (it only is
# if some other package happened to import it already).
from safetensors.torch import save_file

weights_path = "/content//drive/MyDrive/Colab Notebooks/Voederhuiscamera.v2i.multiclass/feeder-efficientvit-m0.safetensors"

# Fall back to the current weights if no epoch ever beat the initial
# best_acc of 0.0 (best_state is still None in that case).
state_to_save = best_state if best_state is not None else model.state_dict()
save_file(state_to_save, weights_path)



100%|██████████| 51/51 [00:14<00:00,  3.62it/s]


epoch 1: train loss 2.0495, valid acc 72.44%
New best checkpoint at epoch 1: 72.44%


100%|██████████| 51/51 [00:13<00:00,  3.72it/s]


epoch 2: train loss 0.8496, valid acc 80.77%
New best checkpoint at epoch 2: 80.77%


100%|██████████| 51/51 [00:13<00:00,  3.72it/s]


epoch 3: train loss 0.4531, valid acc 87.82%
New best checkpoint at epoch 3: 87.82%


100%|██████████| 51/51 [00:13<00:00,  3.74it/s]


epoch 4: train loss 0.2500, valid acc 88.46%
New best checkpoint at epoch 4: 88.46%


100%|██████████| 51/51 [00:13<00:00,  3.74it/s]


epoch 5: train loss 0.1343, valid acc 89.10%
New best checkpoint at epoch 5: 89.10%


100%|██████████| 51/51 [00:13<00:00,  3.72it/s]


epoch 6: train loss 0.0727, valid acc 90.38%
New best checkpoint at epoch 6: 90.38%


100%|██████████| 51/51 [00:13<00:00,  3.74it/s]


epoch 7: train loss 0.0529, valid acc 91.03%
New best checkpoint at epoch 7: 91.03%


100%|██████████| 51/51 [00:13<00:00,  3.69it/s]


epoch 8: train loss 0.0360, valid acc 91.67%
New best checkpoint at epoch 8: 91.67%


100%|██████████| 51/51 [00:13<00:00,  3.65it/s]


epoch 9: train loss 0.0271, valid acc 89.74%


100%|██████████| 51/51 [00:13<00:00,  3.65it/s]


epoch 10: train loss 0.0254, valid acc 92.95%
New best checkpoint at epoch 10: 92.95%


100%|██████████| 51/51 [00:13<00:00,  3.70it/s]


epoch 11: train loss 0.0184, valid acc 90.38%


100%|██████████| 51/51 [00:13<00:00,  3.69it/s]


epoch 12: train loss 0.0212, valid acc 92.31%


100%|██████████| 51/51 [00:14<00:00,  3.62it/s]


epoch 13: train loss 0.0138, valid acc 89.74%


100%|██████████| 51/51 [00:14<00:00,  3.55it/s]


epoch 14: train loss 0.0127, valid acc 91.03%


100%|██████████| 51/51 [00:14<00:00,  3.56it/s]


epoch 15: train loss 0.0186, valid acc 87.18%


100%|██████████| 51/51 [00:14<00:00,  3.55it/s]


epoch 16: train loss 0.0454, valid acc 87.82%


100%|██████████| 51/51 [00:14<00:00,  3.58it/s]


epoch 17: train loss 0.0266, valid acc 87.18%


100%|██████████| 51/51 [00:14<00:00,  3.63it/s]


epoch 18: train loss 0.0250, valid acc 91.67%


100%|██████████| 51/51 [00:14<00:00,  3.63it/s]


epoch 19: train loss 0.0278, valid acc 85.90%


100%|██████████| 51/51 [00:14<00:00,  3.53it/s]


epoch 20: train loss 0.0265, valid acc 89.10%


In [22]:
labels_path = '/content//drive/MyDrive/Colab Notebooks/Voederhuiscamera.v2i.multiclass/feeder-labels.csv'

# Write one class name per line, in the order of `classes`, so line i names
# the class with label index i. The encoding is fixed to UTF-8 so the file
# does not depend on the locale default of the machine that runs this cell.
with open(labels_path, 'w', encoding='utf-8') as f:
    f.writelines(cls + '\n' for cls in classes)

print("Saved labels:", labels_path)


Saved labels: /content//drive/MyDrive/Colab Notebooks/Voederhuiscamera.v2i.multiclass/feeder-labels.csv
