In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from datasets import DatasetDict
from datasets import DatasetDict, Dataset, Features, ClassLabel, Value, Array3D
from transformers import Trainer, TrainingArguments, AutoModelForImageClassification, AutoFeatureExtractor
import numpy as np
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_image_files(folder):
    """Recursively collect the paths of all PNG/JPEG files below ``folder``.

    The extension match is case-insensitive (``.PNG`` and ``.Jpg`` both count).
    Paths are returned in ``os.walk`` traversal order, joined onto the walked
    directory. A non-existent ``folder`` yields an empty list.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(folder)
        for name in names
        if name.lower().endswith(image_exts)
    ]


def load_images(folder):
    """List every image under ``folder`` together with its integer class label.

    Each immediate subdirectory of ``folder`` is one class. Regular files at the
    top level (e.g. ``.DS_Store``, a README) are ignored and do not become
    classes. A class's label is its index in the sorted list of subdirectory
    names.

    Args:
        folder: Root directory containing one subdirectory per class.

    Returns:
        ``(data, class_names)``. ``data`` is a list of
        ``{'image': <file path>, 'label': <int>}`` dicts. ``class_names`` is the
        sorted list of class subdirectory names (``label`` indexes into it).
    """
    # Keep only directories: stray files would otherwise shift the label
    # indices and end up as bogus names in the ClassLabel feature.
    class_names = sorted(
        name for name in os.listdir(folder)
        if os.path.isdir(os.path.join(folder, name))
    )
    data = []
    # enumerate gives the index directly (avoids an O(n) list.index per class).
    for class_index, class_name in enumerate(class_names):
        class_folder = os.path.join(folder, class_name)
        for image_file in get_image_files(class_folder):
            data.append({
                'image': image_file,
                'label': class_index,
            })
    return data, class_names


def generate_dataset(folders):
    """Build an in-memory Hugging Face ``Dataset`` from class-subfolder image trees.

    Every folder in ``folders`` is scanned with ``load_images``. All folders must
    yield the same class-name list so that an integer label means the same class
    everywhere. Each image is opened, converted to RGB, resized to 224x224 and
    stored as a uint8 array. Unreadable files are reported and skipped.

    Args:
        folders: Iterable of root directories, each holding one subdirectory per class.

    Returns:
        ``(dataset, class_names)``. ``dataset`` has an ``'image'`` column
        (``Array3D`` uint8, shape (224, 224, 3)) and a ``'label'`` column
        (``ClassLabel`` over ``class_names``).

    Raises:
        ValueError: if ``folders`` is empty, or the folders disagree on class names.
    """
    all_data = []
    class_names = None
    for folder in folders:
        data, folder_class_names = load_images(folder)
        # Labels are positional indices into class_names, so every folder must
        # produce the identical list or labels would silently mean different classes.
        if class_names is None:
            class_names = folder_class_names
        elif folder_class_names != class_names:
            raise ValueError(
                f"Classes in {folder!r} ({folder_class_names}) do not match "
                f"the earlier folders ({class_names})")
        all_data.extend(data)
    if class_names is None:
        raise ValueError("generate_dataset() needs at least one folder")

    features = Features({
        'image': Array3D(dtype="uint8", shape=(224, 224, 3)),
        'label': ClassLabel(names=class_names)
    })

    data_dict = {
        'image': [],
        'label': []
    }

    for item in all_data:
        try:
            # The context manager closes the file once the pixels have been
            # copied into the converted/resized image.
            with Image.open(item['image']) as raw_image:
                # Resize to fit model input size (224x224)
                image = raw_image.convert('RGB').resize((224, 224))
            data_dict['image'].append(np.array(image))
            data_dict['label'].append(item['label'])
        except Exception as e:
            # Best-effort: skip unreadable/corrupt files but report them.
            print(f"Error processing {item['image']}: {e}")

    return Dataset.from_dict(data_dict, features=features), class_names

In [None]:
# Define your dataset directories
train_folder = '/home/eh_abdol/fine_tune/gold/train'
validation_folder = '/home/eh_abdol/fine_tune/gold/validation'
test_folder = '/home/eh_abdol/fine_tune/gold/test'

# Build every split from its own folder. generate_dataset() stores decoded pixel
# arrays in the 'image' column, not file paths. The splits therefore cannot be
# recovered afterwards by filtering on a path prefix: x['image'] is pixel data,
# so `.startswith(...)` cannot work.
split_folders = {
    'train': train_folder,
    'validation': validation_folder,
    'test': test_folder,
}
splits = {}
class_names = None
for split_name, split_folder in split_folders.items():
    split_dataset, split_class_names = generate_dataset([split_folder])
    # Integer labels index into class_names, so all splits must agree on it.
    if class_names is None:
        class_names = split_class_names
    elif split_class_names != class_names:
        raise ValueError(
            f"Split {split_name!r} has classes {split_class_names}, "
            f"expected {class_names}")
    splits[split_name] = split_dataset

train_data = splits['train']
validation_data = splits['validation']
test_data = splits['test']

# Combine into a DatasetDict
dataset_dict = DatasetDict({
    'train': train_data,
    'validation': validation_data,
    'test': test_data
})

# Save the combined dataset if you want to reuse it later
dataset_dict.save_to_disk('/home/eh_abdol/fine_tune/gold_dataset')

In [2]:
from datasets import load_from_disk

# Reload the saved train/validation/test splits and show their sizes.
GOLD_DATASET_DIR = '/home/eh_abdol/fine_tune/gold_dataset'
dataset_dict = DatasetDict.load_from_disk(GOLD_DATASET_DIR)
dataset_dict

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 0
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 0
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 0
    })
})

In [None]:
# Initialize the model (make sure the architecture is defined and matches the loaded model)
# NOTE(review): `VisionMamba` is not imported anywhere in this notebook; add its
# import (from wherever the VisionMamba source lives) to the imports cell. TODO.
EMBED_DIM = 384   # model width; the replacement head's input size must equal it
NUM_LABELS = 2    # binary classification target

model = VisionMamba(
    patch_size=16,
    stride=8,
    embed_dim=EMBED_DIM,
    depth=24,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    final_pool_type='mean',
    if_abs_pos_embed=True,
    if_rope=False,
    if_rope_residual=False,
    bimamba_type="v2",
    if_cls_token=True,
    if_devide_out=True,
    use_middle_cls_token=True,
    num_classes=1000,  # Original number of classes
    drop_rate=0.0,
    drop_path_rate=0.1,
    drop_block_rate=None,
    img_size=224,
)

# Load the pretrained weights while the model still has its original
# 1000-class head, so the checkpoint's head tensors match in shape. This
# assumes the checkpoint was saved from the 1000-class model, as the
# num_classes comment above indicates.
# Swapping in the 2-class head *before* this call makes load_state_dict
# raise a size-mismatch error on the head parameters.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load("path/to/your/checkpoint.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])

# Modify the head for binary classification (freshly initialised, learned during fine-tuning)
model.head = nn.Linear(in_features=EMBED_DIM, out_features=NUM_LABELS)

# Set model to training mode
model.train()

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Multi-class cross-entropy objective, optimised with Adam at a small fine-tuning LR.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Number of full passes over the training data in the fine-tuning loop.
num_epochs = 10

In [None]:
# Fine-tune for `num_epochs` passes over the training batches.
# NOTE(review): `dataloader` is never built in this notebook. It must be created
# in an earlier cell (presumably from dataset_dict['train']), yielding
# (images, labels) batches the model accepts. TODO add that cell.
if len(dataloader) == 0:
    # Avoids a ZeroDivisionError in the loss average below and a silent no-op run.
    raise ValueError("dataloader is empty - check that the training split has rows")

# Re-enter training mode on every run of this cell. The cell ends in eval mode,
# so a re-run would otherwise train with train-only layers (e.g. drop path,
# drop_path_rate=0.1) disabled.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average of the per-batch losses for this epoch.
    print(
        f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

# Save the fine-tuned model
torch.save(model.state_dict(), "fine_tuned_vision_mamba.pth")

# Set model to evaluation mode
model.eval()

In [None]:
# Example: Making a prediction
# NOTE(review): `transform` is not defined in this notebook. It should presumably
# apply the same preprocessing as training (resize to 224x224, to tensor, any
# normalization the model expects). TODO add it.
model.eval()  # keep the cell self-contained: inference must not run in train mode
with torch.no_grad():
    # Close the image file as soon as its pixels have been converted.
    with Image.open("path/to/test/image.jpg") as raw_image:
        pil_image = raw_image.convert("RGB")
    test_image = transform(pil_image).unsqueeze(0).to(device)  # add batch dim
    prediction = model(test_image)
    predicted_class = torch.argmax(prediction, dim=1).item()
    print(f"Predicted class: {predicted_class}")