In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import timm
from PIL import Image


In [2]:
class PlayingCardDataset(Dataset):
    """Dataset that delegates storage, indexing and labels to an ``ImageFolder``.

    The wrapped ``ImageFolder`` is kept on ``self.data``; every method below
    simply forwards to it.
    """

    def __init__(self, data_dir, transform=None):
        """Build the underlying ``ImageFolder`` rooted at ``data_dir``.

        Args:
            data_dir: Directory handed to ``ImageFolder`` as its root.
            transform: Optional callable forwarded to ``ImageFolder``.
        """
        self.data = ImageFolder(root=data_dir, transform=transform)

    @property
    def classes(self):
        """Class names exactly as reported by the wrapped dataset."""
        wrapped = self.data
        return wrapped.classes

    def __len__(self):
        """Number of items in the wrapped dataset."""
        wrapped = self.data
        return len(wrapped)

    def __getitem__(self, idx):
        """Return whatever the wrapped dataset yields at position ``idx``."""
        item = self.data[idx]
        return item


In [3]:
# Preprocessing pipeline shared by the training/validation datasets and by
# single-image inference: resize to a fixed 128x128, then convert to a tensor.
# NOTE(review): no mean/std normalisation is applied here; the pretrained
# backbone may expect it — confirm before changing, since it alters training.
preprocessing_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)


In [4]:
class SimpleCardClassifier(nn.Module):
    """Image classifier that wraps a single timm backbone.

    The backbone is stored on ``self.base`` and ``forward`` is a straight
    pass-through to it.

    Args:
        num_classes: Number of output classes requested from timm.
        model_name: timm architecture identifier. Defaults to
            ``'efficientnet_b0'``, the architecture this notebook trains.
        pretrained: Whether timm should initialise the backbone from
            pretrained weights. Defaults to ``True``. Pass ``False`` when the
            weights are about to be replaced via ``load_state_dict`` so the
            pretrained checkpoint is not fetched needlessly.
    """

    def __init__(self, num_classes, model_name='efficientnet_b0', pretrained=True):
        super().__init__()
        self.base = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the backbone's output for input batch ``x``."""
        return self.base(x)


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset locations (Kaggle input mount).
train_dir = "/kaggle/input/cards-image-datasetclassification/train"
val_dir = "/kaggle/input/cards-image-datasetclassification/valid"

# Both splits share the same preprocessing pipeline.
train_dataset = PlayingCardDataset(data_dir=train_dir, transform=transform)
val_dataset = PlayingCardDataset(data_dir=val_dir, transform=transform)

# Only the training split is shuffled; validation is iterated in a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

# Size the classifier head from the classes found in the training split.
model = SimpleCardClassifier(num_classes=len(train_dataset.classes))
model.to(device)

Downloading model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

SimpleCardClassifier(
  (base): EfficientNet(
    (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): SiLU(inplace=True)
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): SiLU(inplace=True)
          )
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv_pw): Conv2d(32, 16, kernel_s

In [6]:


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 5


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Module called on each batch of images.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.
        optimizer: When given, the model is put in train mode and updated
            after every batch. When ``None``, the model is put in eval mode
            and gradient tracking is disabled.

    Returns:
        Tuple of Python floats: the sample-weighted mean loss and the
        fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            # criterion returns the batch mean, so weight it by the batch size
            # to keep a smaller final batch from being over-represented.
            loss_sum += loss.item() * batch_size
            # Accumulate plain Python ints instead of device tensors.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError("loader produced no samples")
    # Dividing by the samples actually seen equals dividing by the dataset
    # length for these loaders, which do not drop the last batch.
    return loss_sum / seen, correct / seen


# Per-epoch metrics, kept so the losses are available after training
# instead of being accumulated and thrown away.
history = []

for epoch in range(epochs):
    # -------- Train --------
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)

    # -------- Validation --------
    val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

    history.append({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    })

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    print("-" * 30)


Epoch 1/5
Train Acc: 0.5374 | Val Acc: 0.8604
------------------------------
Epoch 2/5
Train Acc: 0.8351 | Val Acc: 0.8906
------------------------------
Epoch 3/5
Train Acc: 0.9054 | Val Acc: 0.9283
------------------------------
Epoch 4/5
Train Acc: 0.9218 | Val Acc: 0.9321
------------------------------
Epoch 5/5
Train Acc: 0.9525 | Val Acc: 0.9321
------------------------------


In [7]:
# Persist only the model's state dict (not the module object itself).
weights = model.state_dict()
torch.save(weights, "card_model.pth")
print("Model saved successfully")


Model saved successfully


In [8]:
# Rebuild the architecture with the same head size used when the model was
# first created (derived from the training dataset rather than a hard-coded
# 53), so the checkpoint's classifier weights always match the new instance.
model = SimpleCardClassifier(num_classes=len(train_dataset.classes))

# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state dict needs and avoids executing arbitrary pickled code.
state_dict = torch.load("card_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()


SimpleCardClassifier(
  (base): EfficientNet(
    (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): SiLU(inplace=True)
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): SiLU(inplace=True)
          )
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv_pw): Conv2d(32, 16, kernel_s

In [9]:
def predict_image(image_path, model, transform, class_names, device=None):
    """Classify a single image file and return its predicted class name.

    Args:
        image_path: Path of the image to open.
        model: Module called on the preprocessed, batched image.
        transform: Callable turning the RGB image into a tensor.
        class_names: Sequence indexed by the predicted class index.
        device: Device to run on. When ``None`` (default) the device of the
            model's first parameter is used, falling back to the CPU for a
            model without parameters, so the function does not depend on a
            global ``device`` variable.

    Returns:
        The entry of ``class_names`` at the arg-max of the model output.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # The context manager closes the file handle once the pixels are converted.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension the model is called with.
    batch = transform(rgb).unsqueeze(0).to(device)

    # Always predict in eval mode, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(batch)
            pred = outputs.argmax(dim=1)
    finally:
        model.train(was_training)

    return class_names[pred.item()]


In [None]:
# Class names come from the training dataset, the same source that sized the
# model's output layer.
class_names = train_dataset.classes
test_image = "j.jpg"

prediction = predict_image(
    image_path=test_image,
    model=model,
    transform=transform,
    class_names=class_names,
)

print("Predicted card:", prediction)


Predicted card: joker
