In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from transformers import SamModel, SamConfig, AdamW

# Dataset class
class SAMDataset(Dataset):
    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]) if isinstance(self.images[idx], str) else self.images[idx]
        mask = Image.open(self.masks[idx]) if isinstance(self.masks[idx], str) else self.masks[idx]
        
        if self.transform:
            image = self.transform(image)
        
        mask = np.array(mask)
        mask = torch.from_numpy(mask).long()

        return image, mask

# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Assume image_paths and mask_paths are lists of file paths
image_paths = ['path_to_image1.jpg', 'path_to_image2.jpg']
mask_paths = ['path_to_mask1.png', 'path_to_mask2.png']
dataset = SAMDataset(image_paths, mask_paths, transform=transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Model configuration
config = SamConfig.from_pretrained("facebook/sam-vit-base", num_labels=7)
model = SamModel(config)
model.segmentation_head.out_channels = 7  # Modify the segmentation head

# Loss and Optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
def train(model, dataloader, loss_fn, optimizer, epochs=3):
    model.train()
    for epoch in range(epochs):
        for images, masks in dataloader:
            optimizer.zero_grad()
            outputs = model(images).logits  # Adjust depending on the actual model output
            loss = loss_fn(outputs, masks)
            loss.backward()
            optimizer.step()
            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Inference
def predict(model, image):
    model.eval()
    with torch.no_grad():
        prediction = model(image).logits
    return prediction

# Train the model
train(model, dataloader, loss_fn, optimizer)

# Dummy inference example
sample_image, _ = dataset[0]
sample_image = sample_image.unsqueeze(0)  # Add batch dimension
prediction = predict(model, sample_image)
predicted_mask = torch.argmax(prediction, dim=1)  # Convert probabilities to class labels

print("Predicted Mask Shape:", predicted_mask.shape)

AttributeError: 'SamModel' object has no attribute 'segmentation_head'

In [3]:
import roboflow
from roboflow import Roboflow

roboflow.login()

rf = Roboflow()

project = rf.workspace("hashira-fhxpj").project("mri-brain-tumor")
dataset = project.version(1).download("coco")

visit https://app.roboflow.com/auth-cli to get your authentication token.
Paste the authentication token here: ········
loading Roboflow workspace...
loading Roboflow project...


Downloading Dataset Version Zip in MRI-BRAIN-TUMOR-1 to coco:: 100%|█████████████| 1525/1525 [00:00<00:00, 1579.86it/s]





Extracting Dataset Version Zip to MRI-BRAIN-TUMOR-1 in coco:: 100%|██████████████████| 85/85 [00:00<00:00, 2574.59it/s]
