In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

In [None]:
# NOTE: VisionMamba must already be defined or imported in this kernel before the model cell below is run.
class CustomBinaryDataset(Dataset):
    """Map-style dataset of ``(image, label)`` pairs, loading each image lazily from disk.

    Args:
        image_paths: Sequence of image file paths, one per sample.
        labels: Sequence of labels aligned index-for-index with ``image_paths``
            (intended to be 0/1 for binary classification).
        transform: Optional callable applied to each loaded RGB PIL image
            (e.g. a ``torchvision.transforms.Compose``).

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Without this check a mismatch would surface only as an IndexError
        # mid-epoch (labels shorter) or be silently ignored (labels longer),
        # because __len__ counts image paths only.
        if len(image_paths) != len(labels):
            raise ValueError(
                "image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB, apply ``transform`` if given, and return ``(image, label)``."""
        # The context manager guarantees the file handle is released even for
        # multi-frame formats; convert() returns a new, fully loaded image that
        # stays valid after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        label = self.labels[idx]
        # Test for None explicitly rather than relying on the transform object's truthiness.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Preprocessing shared by training and inference: resize to the 224x224 input
# size the model is configured for (img_size=224), convert the PIL image to a
# float tensor, then normalise each channel with the standard ImageNet
# mean / std values.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Example inputs: replace these placeholders (the trailing `...` entries are
# Ellipsis literals, not real paths) with your image files and their 0/1
# labels, keeping both lists the same length and in the same order.
image_paths = ["path/to/image1.jpg", "path/to/image2.jpg", ...]
labels = [0, 1, ...]

# Wrap the samples in a dataset and batch them: 8 per batch, reshuffled every epoch.
dataset = CustomBinaryDataset(image_paths=image_paths, labels=labels, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)

In [None]:
# Build VisionMamba with the ORIGINAL architecture (including its 1000-class
# head) so the pretrained checkpoint's parameter shapes match exactly.
EMBED_DIM = 384   # used for embed_dim and as the new head's in_features (both were 384)
NUM_CLASSES = 2   # binary fine-tuning target

model = VisionMamba(
    patch_size=16,
    stride=8,
    embed_dim=EMBED_DIM,
    depth=24,
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    final_pool_type='mean',
    if_abs_pos_embed=True,
    if_rope=False,
    if_rope_residual=False,
    bimamba_type="v2",
    if_cls_token=True,
    if_devide_out=True,
    use_middle_cls_token=True,
    num_classes=1000,  # Original number of classes (must match the checkpoint)
    drop_rate=0.0,
    drop_path_rate=0.1,
    drop_block_rate=None,
    img_size=224,
)

# Load the pretrained weights BEFORE replacing the head. The checkpoint's head
# was built for num_classes=1000, so loading it into a 2-output head makes
# load_state_dict raise a size-mismatch error.
# NOTE(review): on recent PyTorch, torch.load defaults to weights_only=True; if
# this checkpoint stores non-tensor objects, loading may fail — confirm.
checkpoint = torch.load("path/to/your/checkpoint.pth", map_location="cpu")
# Accept both a training checkpoint ({"model": state_dict, ...}) and a bare state_dict.
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
model.load_state_dict(state_dict)

# Now swap in a freshly initialised binary classifier head.
model.head = nn.Linear(in_features=EMBED_DIM, out_features=NUM_CLASSES)

# Set model to training mode
model.train()

In [None]:
# Place the model on the GPU when CUDA is available, otherwise on the CPU.
# This happens before the optimizer is built around the model's parameters.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Cross-entropy on the linear head's raw logits; Adam over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Number of full passes over the training data in the fine-tuning loop below.
num_epochs = 10

In [None]:
# Fine-tune for `num_epochs` passes over `dataloader`, then save the weights.
for epoch in range(num_epochs):
    # Re-enter training mode every epoch. This cell ends with model.eval(), so
    # re-running it (e.g. to continue training) would otherwise train with
    # train-only layers (presumably the drop-path set by drop_path_rate) disabled.
    model.train()
    running_loss = 0.0
    num_samples = 0
    # `batch_labels` (not `labels`) so the loop does not overwrite the global
    # `labels` list defined in the dataset cell.
    for images, batch_labels in dataloader:
        images = images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss (default reduction='mean') returns a per-batch mean;
        # weight it by batch size so the epoch figure is a true per-sample mean
        # even when the final batch is smaller.
        batch_count = images.size(0)
        running_loss += loss.item() * batch_count
        num_samples += batch_count

    # max(..., 1) avoids division by zero if the dataloader yields nothing.
    epoch_loss = running_loss / max(num_samples, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}")

# Save the fine-tuned parameters (state_dict only, not the pickled module).
torch.save(model.state_dict(), "fine_tuned_vision_mamba.pth")

# Switch to inference behaviour for the prediction cell below.
model.eval()

In [None]:
# Example: predict the class of a single image with the fine-tuned model.
# Set eval mode here as well, so this cell does not depend on the training cell
# having just run; train-only layers stay disabled during inference.
model.eval()

# Close the file promptly; convert() returns an independent, fully loaded image.
with Image.open("path/to/test/image.jpg") as img:
    test_image = img.convert("RGB")
# Same preprocessing as training; unsqueeze(0) adds the leading batch dimension.
test_image = transform(test_image).unsqueeze(0).to(device)

with torch.no_grad():
    prediction = model(test_image)
    predicted_class = torch.argmax(prediction, dim=1).item()
print(f"Predicted class: {predicted_class}")