In [None]:
!pip install -q transformers datasets

In [None]:
!pip install transformers[torch]
!pip install accelerate -U

In [None]:
import random
import torch
from torch.utils.data import DataLoader
from transformers import BeitForImageClassification, BeitImageProcessor
from datasets import load_dataset
from torch.optim import Adam

# Load the dataset
dataset_path = "/kaggle/input/aid-scene-classification-datasets"
dataset = load_dataset('imagefolder', data_dir=dataset_path)

# Split the dataset
# A fixed seed keeps the train/validation membership identical across kernel
# restarts, so accuracy numbers from different runs are comparable.
SPLIT_SEED = 42
splits = dataset['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_ds = splits['train']
val_ds = splits['test']

# Sanity check: the split must neither drop nor duplicate examples.
assert len(train_ds) + len(val_ds) == len(dataset['train'])


In [None]:
# Map each integer class id to its class name, following the order in which
# the training split's 'label' feature lists the names.
label_names = train_ds.features['label'].names
id2label = dict(enumerate(label_names))
print(id2label)

In [None]:
import random
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from transformers import BeitForImageClassification, BeitImageProcessor
from torch.optim import Adam

# Image processor that belongs to the pretrained BEiT checkpoint
checkpoint_name = "microsoft/beit-base-patch16-224"
processor = BeitImageProcessor.from_pretrained(checkpoint_name)

def preprocess_data(examples):
    """Convert a batch of raw examples into model-ready tensors.

    Args:
        examples: Mapping with an 'image' entry (an iterable of objects that
            expose ``.convert`` -- presumably PIL images, TODO confirm) and a
            'label' entry holding the matching integer class ids.

    Returns:
        dict with 'pixel_values' (the processor's output for the RGB-converted
        images) and 'labels' (a ``torch.long`` tensor built from
        ``examples['label']``).
    """
    # Force three channels so every image reaches the processor as RGB.
    rgb_images = []
    for img in examples['image']:
        rgb_images.append(img.convert("RGB"))

    encoded = processor(rgb_images, return_tensors="pt")
    label_tensor = torch.tensor(examples['label'], dtype=torch.long)
    return {'pixel_values': encoded['pixel_values'], 'labels': label_tensor}

# Register preprocess_data as the transform of both splits.
train_ds, val_ds = (
    split.with_transform(preprocess_data) for split in (train_ds, val_ds)
)

# DataLoader
def collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Args:
        batch: Sequence of dicts, each with a 'pixel_values' tensor and a
            'labels' value.

    Returns:
        dict with 'pixel_values' (the per-example tensors stacked along a new
        first dimension, after ``squeeze(0)`` removes a leading size-1
        dimension if present) and 'labels' (a ``torch.long`` tensor with one
        entry per example).
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample['pixel_values'].squeeze(0))
        targets.append(sample['labels'])

    return {
        'pixel_values': torch.stack(images),
        'labels': torch.tensor(targets, dtype=torch.long),
    }

# Batch iterators: both share batch size and collation; only the training
# loader is asked to shuffle.
loader_kwargs = dict(batch_size=32, collate_fn=collate_fn)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# Pretrained BEiT classifier whose label count follows this dataset.
# ignore_mismatched_sizes=True is presumably needed because the checkpoint's
# classification head has a different number of outputs -- TODO confirm.
num_classes = len(dataset['train'].features['label'].names)
model = BeitForImageClassification.from_pretrained(
    "microsoft/beit-base-patch16-224",
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)

# Adam over every model parameter
optimizer = Adam(model.parameters(), lr=5e-5)

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Function to plot the image
def plot_image(image, actual_label, predicted_label, brightness_factor=1.5,
               image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5)):
    """Display one preprocessed image titled with its actual and predicted class.

    Args:
        image: Tensor laid out as (C, H, W); it is permuted to (H, W, C) for
            plotting. Assumed to be the normalised 'pixel_values' produced by
            the image processor -- TODO confirm for other callers.
        actual_label: Integer class id of the ground truth (key of id2label).
        predicted_label: Integer class id predicted by the model.
        brightness_factor: Multiplier applied to the de-normalised pixels;
            1.0 leaves the brightness unchanged.
        image_mean: Per-channel mean used to undo the normalisation. The
            default is assumed to match the BEiT processor's setting; pass
            ``processor.image_mean`` when using a different processor.
        image_std: Per-channel standard deviation, same convention as
            ``image_mean`` (``processor.image_std``).

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    image = image.detach().cpu().float()

    # Undo the per-channel normalisation so pixel values are back in a range
    # that imshow can render as colours.
    mean = torch.tensor(image_mean, dtype=image.dtype).view(-1, 1, 1)
    std = torch.tensor(image_std, dtype=image.dtype).view(-1, 1, 1)
    restored = image * std + mean

    # Brighten with plain torch: deterministic, and needs no torchvision
    # import (the notebook never imports it). Clamp to imshow's float range.
    bright_image = (restored * brightness_factor).clamp(0.0, 1.0)

    plt.imshow(bright_image.permute(1, 2, 0).numpy())  # Convert from (C, H, W) to (H, W, C)
    plt.title(f"Actual: {id2label[actual_label]}, Predicted: {id2label[predicted_label]}")
    plt.axis('off')
    plt.show()

NUM_EPOCHS = 2  # Number of epochs

for epoch in range(NUM_EPOCHS):
    # ---- Training ----
    model.train()
    # Accumulate the loss on-device so the epoch log reports the mean over all
    # batches rather than whatever the final mini-batch happened to produce.
    running_loss = torch.zeros((), device=device)
    num_batches = 0
    for batch in train_loader:
        inputs = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        # Passing labels makes the model return the loss alongside the logits.
        outputs = model(pixel_values=inputs, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()
        num_batches += 1

    # An empty loader yields NaN instead of a NameError on an undefined `loss`.
    mean_loss = running_loss.item() / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch + 1} completed. Mean training loss: {mean_loss:.4f}")

    # ---- Validation ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            predicted = model(pixel_values=inputs).logits.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty validation split (would be a ZeroDivisionError).
    accuracy = correct / total if total else float("nan")
    print(f"Validation Accuracy: {accuracy:.2f}")

    # ---- Qualitative check: one random validation example ----
    random_index = random.randrange(len(val_ds))
    random_sample = val_ds[random_index]
    random_image = random_sample['pixel_values']  # preprocessed image tensor
    random_label = random_sample['labels']

    # The model is still in eval mode here; add a batch dimension for the
    # single image before the forward pass.
    with torch.no_grad():
        logits = model(pixel_values=random_image.unsqueeze(0).to(device)).logits
    predicted_label = logits.argmax(dim=1)

    # int() accepts both a one-element tensor and a plain Python integer.
    plot_image(random_image.cpu(), int(random_label), int(predicted_label))

print("Training completed.")