In [None]:
import os
from torchvision import transforms
from PIL import Image
import random
from torch.utils.data import Dataset
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import torch

In [7]:
# Define transforms for augmentation
fraud_augmentation = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor()
])

class FraudDatasetWithAugmentation(Dataset):
    def __init__(self, image_dir, processor, label_map, augment_fraud=False):
        self.image_paths = []
        self.labels = []
        self.processor = processor
        self.augment_fraud = augment_fraud

        for label_name in os.listdir(image_dir):
            class_dir = os.path.join(image_dir, label_name)
            for fname in os.listdir(class_dir):
                if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(class_dir, fname))
                    self.labels.append(label_map[label_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(path).convert("RGB")

        # Only augment Fraud images during training
        if self.augment_fraud and label == 1:
            image = fraud_augmentation(image)

            # Re-process to match expected ViT input
            processed = self.processor(images=image, return_tensors="pt", do_rescale=False)
        else:
            processed = self.processor(images=image, return_tensors="pt")

        item = {key: val.squeeze(0) for key, val in processed.items()}
        item["labels"] = label
        return item


In [9]:
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.array([0, 1]),
    y=[0]*4000 + [1]*160
)

weights = torch.tensor(class_weights, dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)

NameError: name 'device' is not defined