# Cactus Classification | Vision Transformer (ViT)

## Environment

### Install vit-pytorch and Linformer

In [None]:
!pip -q install vit_pytorch linformer

### Libraries

In [None]:
from __future__ import print_function
import glob
from itertools import chain
import os
import random
import zipfile
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import torchvision.transforms as T
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
from vit_pytorch.efficient import ViT
from collections import Counter

In [None]:
# Report the installed PyTorch version (useful when comparing runs).
print("Torch:", torch.__version__)

### Seed

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator the notebook relies on.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy's global RNG
            and PyTorch (CPU and every visible CUDA device).

    Note:
        ``PYTHONHASHSEED`` is read only at interpreter start-up, so setting it
        here affects Python subprocesses started afterwards, not the string
        hashing of the current kernel.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all GPUs (including the current one); silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    # Use deterministic cuDNN kernels and disable the autotuner, which may pick
    # different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(seed=42)

### Device

In [None]:
print("GPU is available:", torch.cuda.is_available())

# Fall back to the CPU so the notebook still runs on machines without a GPU
# (moving tensors/models to 'cuda' there raises an error).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data

### Download data

In [None]:
!wget https://jivg.org/wp-content/uploads/2024/07/cactus_course_dataset.zip

In [None]:
!unzip cactus_course_dataset.zip

### Load data

In [None]:
# The archive is extracted into the current working directory (/content on
# Colab), so resolve the dataset relative to it rather than hard-coding /content.
base_dir = os.path.abspath('cactus_course_dataset')
train_dir = os.path.join(base_dir, 'training_set')
val_dir = os.path.join(base_dir, 'validation_set')

In [None]:
# List every .jpg under each split, including class sub-folders. glob's order
# depends on the filesystem, so sort it: the seeded train/test split later
# is only reproducible if its input order is fixed.
train_list = sorted(glob.glob(os.path.join(train_dir, '**', '*.jpg'), recursive=True))
val_list = sorted(glob.glob(os.path.join(val_dir, '**', '*.jpg'), recursive=True))

print(f"Train data (Total images): {len(train_list)}")
print(f"Val data (Total images): {len(val_list)}")

In [None]:
# The class name is each image's parent folder (e.g. .../cactus/img.jpg -> "cactus").
# os.path keeps this working with non-"/" path separators.
labels = [os.path.basename(os.path.dirname(path)) for path in train_list]
# Printing one label per image floods the output; show per-class counts instead.
print(f"Labels: {dict(Counter(labels))}")

In [None]:
# Randomly select 9 training images and plot them in a 3x3 grid with their labels.
random_idx = np.random.choice(len(train_list), size=9, replace=False)
print(f"Random indices: {random_idx}")
fig, axes = plt.subplots(3, 3, figsize=(10, 10))

# Separate names for the axes grid and the current axis (the original loop
# variable shadowed the grid it was iterating over).
for idx, ax in zip(random_idx, axes.ravel()):
    # The context manager closes the file once imshow has read the pixels.
    with Image.open(train_list[idx]) as img:
        ax.imshow(img)
    ax.set_title(labels[idx])
    ax.axis('off')

plt.show()

### Split

In [None]:
# Hold out 20% of the training folder as a test set. Re-globbing the full
# training folder here (sorted, for a filesystem-independent order) keeps the
# cell idempotent: re-running it does not split an already-split list again.
all_train_files = sorted(glob.glob(os.path.join(train_dir, '**', '*.jpg'), recursive=True))
train_list, test_list = train_test_split(all_train_files, test_size=0.2, random_state=42)

print(f"Train data (Total images): {len(train_list)}")
print(f"Val data (Total images): {len(val_list)}")
print(f"Test data (Total images): {len(test_list)}")


In [None]:
# Class label of each image = name of its parent folder.
train_labels, val_labels, test_labels = (
    [p.split('/')[-2] for p in paths] for paths in (train_list, val_list, test_list)
)
train_counts, val_counts, test_counts = map(Counter, (train_labels, val_labels, test_labels))

# Per-split class counts (column headers are in Spanish, as in the report).
data = {"Clase": ["Cactus", "No Cactus"]}
for column, counts in (
    ("Entrenamiento", train_counts),
    ("Validación", val_counts),
    ("Prueba", test_counts),
):
    data[column] = [counts.get("cactus", 0), counts.get("no_cactus", 0)]

set_classes = pd.DataFrame(data)

set_classes

### Image Augmentation

In [None]:
# Training pipeline: resize, then random resized crop + horizontal flip as augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])


def _eval_transforms():
    """Build the deterministic evaluation pipeline: resize to 256, center-crop 224."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])


# Validation and test use the same (non-random) preprocessing.
val_transforms = _eval_transforms()
test_transforms = _eval_transforms()

### Load Datasets

In [None]:
class CactusDataset(Dataset):
    """Image dataset whose binary label comes from each file's parent folder.

    Images stored in a folder named ``cactus`` get label 1; any other folder
    name gets label 0.

    Args:
        file_list: sequence of image file paths.
        transform: optional callable applied to the loaded PIL image. When it is
            ``None``, the RGB PIL image itself is returned.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 if the image's parent folder is ``cactus``, else 0."""
        folder = os.path.basename(os.path.dirname(img_path))
        return 1 if folder == "cactus" else 0

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        # Load inside a context manager so the file handle is released. Force
        # 3-channel RGB because the ViT in this notebook is built with
        # channels=3; grayscale/RGBA files would otherwise yield the wrong
        # number of channels.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, self._label_from_path(img_path)

In [None]:
# One dataset per split; only the training split uses the augmenting transforms.
train_data = CactusDataset(file_list=train_list, transform=train_transforms)
val_data = CactusDataset(file_list=val_list, transform=val_transforms)
test_data = CactusDataset(file_list=test_list, transform=test_transforms)

In [None]:
batch_size = 64

# Shuffle only the training batches; evaluation loaders keep a fixed order.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False) for split in (val_data, test_data)
)

In [None]:
# Report image and batch counts for each split.
for split_name, split_data, split_loader in (
    ("Train", train_data, train_loader),
    ("Val", val_data, val_loader),
    ("Test", test_data, test_loader),
):
    print(f"{split_name} data | Total images: {len(split_data)} | Total batches: {len(split_loader)}")

### Verify that the datasets are loaded correctly

In [None]:
def show_images(dataloader, title, num_images=8):
    """Plot the first images of one batch from ``dataloader`` with their labels.

    Args:
        dataloader: iterable yielding ``(images, labels)`` batches. ``images``
            is assumed to be a (batch, C, H, W) tensor, as produced by
            ``transforms.ToTensor`` — TODO confirm for other loaders.
        title: figure super-title.
        num_images: maximum number of images to show (default 8). Fewer are
            shown when the batch is smaller, instead of raising IndexError.
    """
    images, batch_labels = next(iter(dataloader))
    count = min(num_images, len(images))

    # squeeze=False keeps `axes` 2-D even when count == 1, so indexing is uniform.
    fig, axes = plt.subplots(1, count, figsize=(2.5 * count, 3), squeeze=False)
    fig.suptitle(title, fontsize=16)

    for ax, img, label in zip(axes[0], images[:count], batch_labels[:count]):
        label = int(label)  # plain int so the title reads "1", not "tensor(1)"
        ax.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
        ax.set_title(f"Label:[{label} | {'Cactus' if label == 1 else 'No Cactus'}]")
        ax.axis('off')

    plt.show()

# Preview one batch from each split.
for loader, title in (
    (train_loader, "Training Set Samples"),
    (val_loader, "Validation Set Samples"),
    (test_loader, "Test Set Samples"),
):
    show_images(loader, title)


## Model | Vision Transformer

### Linformer

In [None]:
# Linformer backbone plugged into the ViT below. seq_len must equal the number
# of image patches plus the class token: (224 // 32) ** 2 + 1 = 7 * 7 + 1 = 50.
efficient_transformer = Linformer(
    dim=128,                      # must match the ViT's dim
    seq_len=(224 // 32) ** 2 + 1,
    depth=12,
    heads=8,
    k=64,                         # Linformer projection size
)

### Vision Transformer

In [None]:
# ViT: 224x224 RGB inputs cut into 32x32 patches, two output classes
# (cactus / no cactus), using the Linformer defined above as its transformer.
model = ViT(
    image_size=224,
    patch_size=32,
    channels=3,
    dim=128,
    transformer=efficient_transformer,
    num_classes=2,
).to(device)

### Training

Training settings | Loss function | Optimizer | Scheduler

In [None]:
# Training settings
epochs = 1  # short demo run; increase for a properly trained model
# Adam with lr=0.3 makes a transformer diverge almost immediately; 3e-5 is the
# value used in vit-pytorch's Linformer example with this same architecture.
lr = 3e-5
# StepLR(step_size=1) multiplies the lr by gamma after every epoch; 0.1 would
# shrink it 10x per epoch, so use a gentler decay.
gamma = 0.7

# Loss function
criterion = nn.CrossEntropyLoss()
# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# Scheduler (must be stepped once per epoch by the training loop)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

Model training

In [None]:
def _run_epoch(loader, train):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With ``train=True`` the model is put in training mode and updated with
    ``optimizer``. Otherwise it is put in eval mode with gradients disabled.
    Both metrics are averaged per sample rather than per batch, so a smaller
    final batch is not over-weighted. Returns ``(nan, nan)`` for an empty loader.
    """
    model.train(train)
    total_loss, total_correct, total_seen = 0.0, 0, 0
    grad_context = torch.enable_grad() if train else torch.no_grad()
    batches = tqdm(loader) if train else loader
    with grad_context:
        for inputs, targets in batches:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)  # batch-mean loss (default reduction)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # .item() converts to Python numbers, so the running totals do not
            # keep each batch's autograd graph alive.
            n_batch = targets.size(0)
            total_loss += loss.item() * n_batch
            total_correct += (outputs.argmax(dim=1) == targets).sum().item()
            total_seen += n_batch

    if total_seen == 0:
        return float('nan'), float('nan')
    return total_loss / total_seen, total_correct / total_seen


# Per-epoch metrics, kept for the plots below
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

for epoch in range(epochs):
    epoch_loss, epoch_accuracy = _run_epoch(train_loader, train=True)
    epoch_val_loss, epoch_val_accuracy = _run_epoch(val_loader, train=False)
    # Advance the StepLR schedule once per epoch (it was built with step_size=1).
    scheduler.step()

    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_accuracy)

    print(
        f"Epoch: {epoch+1}, Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_accuracy:.4f}, "
        f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_accuracy:.4f}"
    )


Training and validation loss | Training and validation accuracy

In [None]:
def _plot_curves(train_values, val_values, train_label, val_label, ylabel, title, filename):
    """Plot per-epoch training/validation curves, save them to ``filename`` and show them."""
    # 1-based x axis so it matches the "Epoch: N" numbering of the training log.
    epochs_axis = range(1, len(train_values) + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    # Markers keep the curves visible even when only a single epoch was run.
    ax.plot(epochs_axis, train_values, marker='o', label=train_label)
    ax.plot(epochs_axis, val_values, marker='o', label=val_label)
    ax.set_xlabel('Épocas')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    fig.savefig(filename)
    plt.show()


# Training vs. validation loss
_plot_curves(
    train_losses, val_losses,
    'Pérdida de Entrenamiento', 'Pérdida de Validación',
    'Pérdida', 'Pérdida de Entrenamiento y Validación',
    "perdida_entrenamiento_validacion.png",
)

# Training vs. validation accuracy
_plot_curves(
    train_accuracies, val_accuracies,
    'Exactitud de Entrenamiento', 'Exactitud de Validación',
    'Exactitud', 'Exactitud de Entrenamiento y Validación',
    "exactitud_entrenamiento_validacion.png",
)
