In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import timm
import pandas as pd
import os
from tqdm import tqdm
from PIL import Image

In [None]:
# Mount Google Drive (Colab only) so that the /content/drive/MyDrive/... paths
# used for the test images, the checkpoint and the output CSV are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
class CAPTCHAModel(nn.Module):
    """EfficientNet-B0 classifier with a dropout + linear classification head.

    The layer layout must match the model used during training, because the
    saved ``state_dict`` keys (e.g. ``model.classifier.1.weight``) depend on it.
    NOTE(review): the original comment points to the eval section of the
    training script for details; keep both definitions in sync.

    Args:
        num_classes: Number of output logits produced by the head.
        pretrained: Passed to ``timm.create_model``. Defaults to True, which is
            the previous behaviour. Pass False when every weight is overwritten
            from a checkpoint anyway; this skips the download of pretrained
            backbone weights.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.model = timm.create_model('efficientnet_b0', pretrained=pretrained)
        # Replace the stock classifier. Index 0 is Dropout and index 1 is the
        # Linear layer; these indices appear in the checkpoint key names.
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch of images ``x``."""
        return self.model(x)

# Test-time preprocessing: resize every image to 224x224, convert it to a
# float tensor and normalise each RGB channel with the given mean/std (the
# commonly used ImageNet channel statistics).
# NOTE(review): should mirror the evaluation transforms used in training.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def load_model(model_path, device):
    """Build a CAPTCHAModel and restore its trained weights from a checkpoint.

    Args:
        model_path: Path to a ``torch.save`` checkpoint. It must be a dict with
            at least the keys ``'model_state_dict'`` and ``'classes'``.
        device: ``torch.device`` on which the model and checkpoint tensors
            are placed.

    Returns:
        Tuple ``(model, classes)``. ``model`` is in eval mode on ``device``.
        ``classes`` is the label collection stored in the checkpoint.
    """
    # map_location remaps stored tensors to `device`, so a checkpoint saved on
    # a GPU can still be loaded on a CPU-only runtime.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. The cell output shows torch warning on this call,
    # presumably about the `weights_only` default; consider weights_only=True
    # once the checkpoint is confirmed to hold only tensors/plain containers.
    checkpoint = torch.load(model_path, map_location=device)
    classes = checkpoint['classes']
    # Size the head from the stored labels rather than a hard-coded count.
    # This assumes one label per output logit.
    model = CAPTCHAModel(num_classes=len(classes)).to(device)
    # Load the saved weights into the model.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, classes

class TestDataset(torch.utils.data.Dataset):
    """Unlabelled test images read from a single flat directory.

    Each item is an ``(image, file_name)`` pair, so predictions can be matched
    back to their source file. Files are ordered by the integer value of their
    stem, so ``2.png`` comes before ``10.png``. Files whose stem is not an
    integer are placed after those, ordered by name, instead of crashing the
    sort.

    Args:
        test_dir: Directory containing the ``.png`` test images.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, test_dir, transform=None):
        self.test_dir = test_dir
        self.transform = transform
        # Sorted list of image file names. The extension check is
        # case-insensitive so that e.g. "7.PNG" is not silently skipped.
        self.images = sorted(
            (f for f in os.listdir(test_dir) if f.lower().endswith('.png')),
            key=self._sort_key,
        )

    @staticmethod
    def _sort_key(file_name):
        """Sort key: integer stems first (by value), then other names (by name)."""
        stem = os.path.splitext(file_name)[0]
        try:
            return (0, int(stem), file_name)
        except ValueError:
            return (1, 0, file_name)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.test_dir, img_name)
        # The context manager closes the file handle immediately; convert()
        # returns a separate, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

def predict_test_set(model, test_loader, device, classes):
    """Run ``model`` over every batch of ``test_loader`` and collect softmax scores.

    Args:
        model: Network whose output is softmaxed over dim 1. Presumably this
            is logits of shape (batch, num_classes); verify against the model.
        test_loader: Iterable yielding ``(inputs, names)`` batches.
        device: Device each input batch is moved to before the forward pass.
        classes: Not used by this function; kept for call-site compatibility.

    Returns:
        ``(filenames, predictions)``: two parallel lists in loader order. Each
        prediction is a NumPy probability vector.
    """
    model.eval()
    collected_names = []
    collected_probs = []

    with torch.no_grad():
        for batch, batch_names in tqdm(test_loader, desc="Predicting"):
            logits = model(batch.to(device))
            # Turn each row of raw scores into class probabilities.
            batch_probs = logits.softmax(dim=1)
            collected_probs.extend(batch_probs.cpu().numpy())
            collected_names.extend(batch_names)

    return collected_names, collected_probs

def save_predictions(filenames, predictions, classes, output_path):
    """Write per-image class probabilities to a CSV file.

    The CSV has an ``ImageName`` column followed by one column per entry of
    ``classes``. Row ``i`` pairs ``filenames[i]`` with ``predictions[i]``. The
    DataFrame index is not written.

    Args:
        filenames: Image file names, one per row.
        predictions: Per-image probability vectors, each ``len(classes)`` long.
        classes: Labels used as the probability column headers.
        output_path: Destination path of the CSV file.
    """
    probability_table = pd.DataFrame(data=predictions, columns=classes)
    probability_table.insert(loc=0, column='ImageName', value=filenames)
    probability_table.to_csv(path_or_buf=output_path, index=False)
    print(f"Predictions saved to {output_path}")

In [None]:
def main(test_dir='/content/drive/MyDrive/ML1/test_data/test',
         model_path='/content/drive/MyDrive/ML1/Best_Model/best_model.pth',
         output_path='/content/drive/MyDrive/ML1/predictions.csv',
         batch_size=64,
         num_workers=4):
    """Predict class probabilities for every test image and write them to CSV.

    Every argument defaults to the previously hard-coded Google Drive location
    or loader setting, so ``main()`` behaves as before. Override the arguments
    to run on other data or hardware.

    Args:
        test_dir: Directory with the ``.png`` test images.
        model_path: Checkpoint file to load the trained model from.
        output_path: Destination CSV for the predictions.
        batch_size: Images per inference batch.
        num_workers: DataLoader worker processes.
    """
    # Use the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the trained model.
    model, classes = load_model(model_path, device)
    model.eval()

    # Test dataset and loader.
    test_dataset = TestDataset(test_dir, transform=test_transforms)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        # Keep the dataset's sorted file order in the output rows.
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
        # runtime it brings no benefit.
        pin_memory=device.type == 'cuda',
    )

    # Predictions.
    print("Starting predictions...")
    filenames, predictions = predict_test_set(model, test_loader, device, classes)

    # Save the results.
    save_predictions(filenames, predictions, classes, output_path)

In [None]:
# Inside Jupyter the notebook namespace is "__main__", so running this cell
# starts the full prediction run.
if __name__ == "__main__":
    main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

  checkpoint = torch.load(model_path)


Starting predictions...


Predicting: 100%|██████████| 137/137 [03:47<00:00,  1.66s/it]


Predictions saved to /content/drive/MyDrive/ML1/predictions.csv
