# Hybrid ViT–CNN Architecture for Lithium-Ion Battery Type Identification on the RecyBat24 Dataset

---

**Importing Required Libraries**

In [None]:
import os
import random
import json
import numpy as np

# Image processing libraries
from PIL import Image

# Data Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns


# Deep Learning libraries
import torch
import timm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import resnet50, ResNet50_Weights

# Model Evaluation
from sklearn.metrics import classification_report, confusion_matrix

# Other libraries
import warnings
warnings.filterwarnings('ignore')

**Mounting Google Drive**

In [None]:
# Mount Google Drive into the runtime; `google.colab` is only importable
# inside a Colab session.
from google.colab import drive
drive.mount('/content/drive')
# Dataset root on the mounted drive. The 'train' and 'val' sub-folders are
# joined onto this path further down.
data_dir = '/content/drive/MyDrive/CVA/recybat24'

**Helper functions**

In [None]:
# Custom Dataset class
class RecyBatDataset(Dataset):
    """Image-classification dataset described by an ``annotations.json`` file.

    ``root_dir`` must contain ``annotations.json`` (COCO-like layout) with
    three lists:

    * ``images``       -- entries with ``id`` and ``file_name``
    * ``annotations``  -- entries with ``image_id`` and ``category_id``
    * ``categories``   -- entries with ``id`` and ``name``

    Image files are resolved relative to ``root_dir``. Images without any
    annotation are skipped. If several annotations reference the same image,
    the last one in the file determines its label.

    Args:
        root_dir: Directory holding ``annotations.json`` and the images.
        transform: Optional callable applied to the loaded RGB PIL image.

    Attributes:
        samples: List of ``(image_path, class_index)`` pairs.
        classes: Class names; position ``i`` is the name of class index ``i``
            (the order of ``categories`` in the annotation file).
        label_map: Category id -> class name.
        label_to_index: Category id -> 0-based class index.
        id_to_filename: Image id -> file name.
        id_to_label: Image id -> category id.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_dir = root_dir

        # JSON is UTF-8 by specification, so do not rely on the platform's
        # default text encoding when reading it.
        annotation_path = os.path.join(root_dir, "annotations.json")
        with open(annotation_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Image ID -> filename
        self.id_to_filename = {
            img["id"]: img["file_name"] for img in data["images"]
        }

        # Image ID -> label ID (a later annotation overwrites an earlier one)
        self.id_to_label = {
            ann["image_id"]: ann["category_id"] for ann in data["annotations"]
        }

        # Label ID -> class name
        self.label_map = {
            cat["id"]: cat["name"] for cat in data["categories"]
        }

        # Category ids are arbitrary; map them to contiguous 0-based indices
        # in the order the categories appear in the file.
        self.label_to_index = {
            label_id: idx for idx, label_id in enumerate(self.label_map)
        }

        # Keep only the images that actually have a label.
        self.samples = [
            (
                os.path.join(self.image_dir, file_name),
                self.label_to_index[self.id_to_label[img_id]],
            )
            for img_id, file_name in self.id_to_filename.items()
            if img_id in self.id_to_label
        ]

        self.classes = list(self.label_map.values())

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_index)``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied.

        Raises:
            FileNotFoundError: If the image file listed in the annotations
                does not exist on disk.
        """
        img_path, label = self.samples[idx]
        try:
            # The context manager releases the file handle; convert() returns
            # an independent in-memory image.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Image listed in annotations.json is missing: {img_path}"
            ) from err
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Image Denormalizer
def denormalize(img_tensor):
    """Undo the mean/std normalization applied by the notebook's transforms.

    Args:
        img_tensor: Normalized image tensor whose 3-channel axis is followed
            by two spatial axes, e.g. ``(3, H, W)`` or ``(N, 3, H, W)``.

    Returns:
        A new tensor of the same shape with values clamped to ``[0, 1]``,
        located on the same device as ``img_tensor``.
    """
    # Build the statistics on the input's device so that CUDA tensors can be
    # denormalized without an explicit .cpu() first.
    device = img_tensor.device
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)
    # Inverse of (x - mean) / std, then clip rounding overshoot.
    return (img_tensor * std + mean).clamp(0, 1)

In [None]:
# Function to train the model for one epoch
def train_one_epoch(model, dataloader, optimizer, criterion):
    """Run a single optimisation pass over ``dataloader``.

    The model is switched to training mode and every batch is moved to the
    module-level ``DEVICE`` before the forward pass.

    Args:
        model: Network mapping a batch of images to per-class scores.
        dataloader: Sized iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepping the trainable parameters.
        criterion: Loss callable taking ``(outputs, labels)``.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch loss values
        and the fraction of samples classified correctly.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

        # Predicted class = index of the highest score along the class axis.
        predictions = torch.max(logits, 1).indices
        n_correct += (predictions == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    # Loss is averaged over batches, accuracy over individual samples.
    return loss_sum / len(dataloader), n_correct / n_seen

In [None]:
# Function to evaluate model
def evaluate(model, dataloader, criterion=None):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    The model is switched to eval mode and every batch is moved to the
    module-level ``DEVICE`` before the forward pass.

    Args:
        model: Network mapping a batch of images to per-class scores.
        dataloader: Sized iterable yielding ``(images, labels)`` batches.
        criterion: Optional loss callable taking ``(outputs, labels)``.
            When ``None`` no loss is computed.

    Returns:
        ``(avg_loss, acc, all_labels, all_preds)`` where ``avg_loss`` is the
        mean of the per-batch loss values (``None`` when ``criterion`` is
        ``None``), ``acc`` is the fraction of correctly classified samples,
        and the two lists hold the true and predicted class index of every
        sample in iteration order.
    """
    # Decide once, by identity, whether a loss is requested. The same flag
    # guards both accumulation and reporting, so a criterion object that
    # happens to be falsy still gets its loss returned.
    compute_loss = criterion is not None

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = model(images)

            if compute_loss:
                running_loss += criterion(outputs, labels).item()

            # Predicted class = index of the highest score per sample.
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            # Collected on the CPU for the metric functions used later.
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    acc = correct / total
    avg_loss = running_loss / len(dataloader) if compute_loss else None

    return avg_loss, acc, all_labels, all_preds

**Exploratory Data Analysis**

In [None]:
# Transform-free view of the training split, used only for the exploratory
# statistics below (they read `samples` / `classes`, no image is loaded).
train_dataset_temp = RecyBatDataset(
    root_dir = os.path.join(data_dir, "train"),
    transform=None
)

In [None]:
# Visualizing class distribution
# `labels` holds the 0-based class index of every training sample.
labels = [label for _, label in train_dataset_temp.samples]
class_names = train_dataset_temp.classes

# Count samples per class explicitly. A histogram with `bins=K` over integer
# labels places its bin edges at multiples of (K-1)/K, so the bars drift away
# from the integer tick positions that carry the class names.
class_counts = np.bincount(
    np.asarray(labels, dtype=np.int64),
    minlength=len(class_names)
)

plt.figure(figsize=(10,5))
plt.bar(range(len(class_names)), class_counts)
plt.xticks(range(len(class_names)), class_names, rotation=45)
plt.title("Class Distribution - RecyBat24")
plt.show()

**Data Augmentation and Preprocessing**

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Side length (pixels) every image is resized to; matches the 224-pixel ViT
# variant ("vit_base_patch16_224") created further down.
IMAGE_SIZE = 224
# `num_workers` passed to both DataLoaders.
NUM_WORKERS = 4
# Samples per mini-batch for both DataLoaders.
BATCH_SIZE = 32

In [None]:
# Values shared by the training and evaluation pipelines.
_TARGET_SIZE = (IMAGE_SIZE, IMAGE_SIZE)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation (flip, small rotation,
# brightness/contrast jitter), then tensor conversion and normalization.
train_transforms = transforms.Compose([
    transforms.Resize(_TARGET_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Evaluation pipeline: same resize and normalization, no random augmentation.
test_transforms = transforms.Compose([
    transforms.Resize(_TARGET_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

**Dataset & Dataset loaders**

In [None]:
# Training split with augmentation.
train_dataset = RecyBatDataset(
    root_dir=os.path.join(data_dir, 'train'),
    transform=train_transforms,
)

# Held-out split: read from the 'val' folder, deterministic preprocessing.
# NOTE(review): class indices follow the category order of each split's own
# annotations.json -- confirm both files list categories in the same order.
test_dataset = RecyBatDataset(
    root_dir=os.path.join(data_dir, 'val'),
    transform=test_transforms,
)

# Options common to both loaders; only shuffling differs.
_loader_options = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Width of the classifier output, taken from the training annotations.
NUM_CLASSES = len(train_dataset.classes)

**Hybrid CNN-ViT model definition**

In [None]:
# CNN branch: ResNet-50 initialised with torchvision's default pretrained
# weights.
cnn_backbone = resnet50(weights=ResNet50_Weights.DEFAULT)

# Remember the width of the features feeding the original classifier, then
# replace that classifier with a pass-through so the backbone outputs the
# features themselves. The order matters: `in_features` must be read before
# `fc` is swapped out.
cnn_feature_dim = cnn_backbone.fc.in_features
cnn_backbone.fc = nn.Identity()

# Transformer branch: pretrained ViT-Base (patch 16, 224 px) from timm.
# `num_classes=0` presumably yields a headless model returning pooled
# features of width `vit.num_features` -- confirm against the timm docs.
vit = timm.create_model(
    "vit_base_patch16_224",
    pretrained=True,
    num_classes=0
)

In [None]:
class HybridCNNViT(nn.Module):
    """Two-branch classifier that fuses CNN and ViT features by concatenation.

    Both branches receive the same input batch. Their feature vectors are
    joined along dimension 1 and mapped to class scores by one linear layer.

    Args:
        cnn: Module returning a ``(batch, cnn_feature_dim)`` feature tensor.
        vit: Module returning a ``(batch, vit.num_features)`` feature tensor;
            must expose a ``num_features`` attribute.
        cnn_feature_dim: Width of the CNN branch's output.
        num_classes: Number of output classes.
    """

    def __init__(self, cnn, vit, cnn_feature_dim, num_classes):
        super().__init__()
        self.cnn = cnn
        self.vit = vit

        # The classifier consumes both feature vectors side by side.
        fused_dim = cnn_feature_dim + vit.num_features
        self.fc = nn.Linear(fused_dim, num_classes)

    def forward(self, x):
        """Return unnormalised class scores for the input batch ``x``."""
        branch_features = (self.cnn(x), self.vit(x))
        fused = torch.cat(branch_features, dim=1)
        return self.fc(fused)

In [None]:
# Assemble the hybrid network from the two backbones and move all of its
# parameters to DEVICE.
model = HybridCNNViT(
    cnn = cnn_backbone,
    vit = vit,
    cnn_feature_dim = cnn_feature_dim,
    num_classes = NUM_CLASSES
).to( DEVICE )

**Model training and Evaluation**

In [None]:
# Training hyperparameters
EPOCHS = 10  # number of passes over the training loader
LR = 3e-4  # learning rate handed to AdamW
WD = 1e-4  # weight decay handed to AdamW

In [None]:
# Freeze both pretrained backbones: none of their parameters receives
# gradients, so only the fusion layer `model.fc` is trained.
# NOTE(review): train_one_epoch still calls model.train(), so layers with
# train-time behaviour inside the frozen backbones (e.g. BatchNorm running
# statistics) remain active -- confirm that this is intended.
for backbone in (model.cnn, model.vit):
    backbone.requires_grad_(False)

In [None]:
# Classification loss applied to the scores produced by `model.fc`.
criterion = nn.CrossEntropyLoss()

# Only the fusion layer's parameters are handed to the optimizer; the
# backbone parameters are not registered here.
optimizer = torch.optim.AdamW(
    model.fc.parameters(),
    lr = LR,
    weight_decay = WD
)

In [None]:
# Main training loop: one pass over the training loader per epoch, logging
# the mean batch loss and the training accuracy. No held-out evaluation is
# run between epochs.
for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, criterion
    )

    # `epoch` is 0-based; print it 1-based for readability.
    print(
        f"Epoch [{epoch+1}/{EPOCHS}] | "
        f"Train Loss: {train_loss:.4f} | "
        f"Train Acc: {train_acc:.4f}"
    )

In [None]:
# Final evaluation on `test_loader` (built from the 'val' folder). Keeps the
# per-sample true and predicted class indices for the reports below.
test_loss, test_acc, y_true, y_pred = evaluate(
    model, test_loader, criterion
)

print(f"Test Accuracy: {test_acc:.4f}")

In [None]:
# Per-class precision / recall / F1.
# `labels` pins the report to every known class index. Without it,
# scikit-learn infers the classes from y_true/y_pred and raises when a class
# never occurs, because `target_names` then has the wrong length.
print(
    classification_report(
        y_true,
        y_pred,
        labels=list(range(len(test_dataset.classes))),
        target_names=test_dataset.classes
    )
)

In [None]:
# Confusion matrix with one row/column per known class. Passing `labels`
# keeps the matrix at len(classes) x len(classes) even if a class is absent
# from both y_true and y_pred, so the tick labels below always line up.
cm = confusion_matrix(
    y_true,
    y_pred,
    labels=list(range(len(test_dataset.classes)))
)

plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=test_dataset.classes,
    yticklabels=test_dataset.classes
)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()

**Making Predictions**

In [None]:
# Class names indexed by class index, used to label the prediction below.
LB = test_dataset.classes

# Pick one random sample from the held-out dataset (randint is inclusive on
# both ends, hence the -1). `image` is already transformed and normalized.
idx = random.randint(0, len(test_dataset) - 1)
image, true_label = test_dataset[idx]

In [None]:
# Switch to inference mode and add a leading batch axis, since the model
# expects a batch of images.
model.eval()
input_image = image.unsqueeze(0).to(DEVICE)  # shape: [1, 3, 224, 224]

In [None]:
# Forward pass without gradient tracking; softmax over the class axis turns
# the scores into probabilities that sum to 1.
with torch.no_grad():
    outputs = model(input_image)
    probabilities = torch.softmax(outputs, dim=1)

In [None]:
# Index of the most probable class for the single image in the batch.
predicted_class = torch.argmax(probabilities, dim=1).item()
# Drop the batch axis and move to NumPy: one probability per class.
confidence_scores = probabilities.squeeze().cpu().numpy()

In [None]:
# Undo the normalization so pixel values are back in [0, 1], then move the
# channel axis last (C, H, W -> H, W, C) as expected by imshow.
img_vis = denormalize(image)
img_vis = img_vis.permute(1, 2, 0).numpy()

# Show the image with its true and predicted class names.
plt.figure(figsize=(4, 4))
plt.imshow(img_vis)
plt.axis("off")
plt.title(
    f"True: {LB[true_label]}\n"
    f"Predicted: {LB[predicted_class]}"
)
plt.show()

In [None]:
# Print the predicted probability of every class next to its name.
# The names come from `LB`; the global `labels` is the list of integer
# training labels from the class-distribution cell, not class names.
print("Confidence scores:")
for class_name, score in zip(LB, confidence_scores):
    print(f"{class_name}: {score:.4f}")