In [None]:
#This is Good Practioce for the moment

!rm -rf /opt/conda/lib/python3.10/site-packages/fsspec*
!pip install fsspec==2024.6.0 --force-reinstall --no-deps
!pip install opencv-python

In [None]:
# Install system libraries also needed to visualize figures.
# libgl1-mesa-glx provides libGL and libglib2.0-0 provides GLib; presumably these are the
# shared libraries opencv-python (pip-installed in the previous cell) needs in a minimal
# container — confirm.
# NOTE(review): newer Debian/Ubuntu releases may not ship libgl1-mesa-glx (the package is
# named libgl1 there); switch the name if apt-get cannot find it.
!sudo apt-get update
!sudo apt-get install -y libgl1-mesa-glx
!sudo apt-get install -y libglib2.0-0

In [153]:
# Imports used by this and the following cells. Declaring them in the first Python cell
# lets the notebook survive "Restart Kernel -> Run All" instead of relying on names that
# already happen to exist in the kernel.
import glob
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm.auto import tqdm

# Set the root directory for your Kaggle files
rd = './kaggle-files'

# Map the labels to integers for multi-class classification.
# -100 marks a missing label: it is the ignore_index used by the loss later on.
label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}

# Load the main CSV file: a 'study_id' column, every other column is a label.
df = pd.read_csv(f'{rd}/train.csv')
label_cols = [col for col in df.columns if col != 'study_id']

# Fail fast on label strings that label2id does not know about, instead of letting them
# survive as text and crash much later inside the Dataset.
observed_labels = {v for v in pd.unique(df[label_cols].to_numpy().ravel()) if pd.notna(v)}
unknown_labels = observed_labels - set(label2id)
if unknown_labels:
    raise ValueError(f"Unexpected label values in train.csv: {sorted(map(str, unknown_labels))}")

# Explicit map + int cast: avoids pandas' deprecated silent downcasting in `replace`
# and guarantees an integer dtype for every label column.
df[label_cols] = (
    df[label_cols]
    .apply(lambda col: col.map(label2id))
    .fillna(-100)
    .astype('int64')
)

# Load the coordinates data, keeping only rows with a known slice_number
coordinates_df = (
    pd.read_csv(f'{rd}/dfc_updated.csv')
    .dropna(subset=['slice_number'])
    .astype({'slice_number': int})
)

# Load the series descriptions; 'T2/STIR' -> 'T2_STIR' so the name is usable as a folder name
series_description_df = pd.read_csv(f'{rd}/train_series_descriptions.csv')
series_description_df['series_description'] = series_description_df['series_description'].str.replace(
    'T2/STIR', 'T2_STIR', regex=False
)

# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']


In [154]:
class LumbarSpineDataset(Dataset):
    """Study-level dataset: one sample per unique ``study_id`` of ``df``.

    Each sample is a dict with:
      - 'study_id': the study identifier.
      - 'images': {series_description: tensor}. The slices of a series are stacked on a
        new leading axis (e.g. [num_slices, 1, H, W] when ``transform`` ends in ToTensor
        on a grayscale image). A series with no images becomes a zero tensor of shape
        ``PLACEHOLDER_SHAPE``.
      - 'labels': LongTensor with one entry per label column of ``df``; -100 = missing.
      - 'annotations': {f"{condition}_{level}": {series_description: [
        {'x': x_scaled, 'y': y_scaled, 'slice_number': int}, ...]}}, built from the rows
        of ``coordinates_df`` belonging to the study.

    Args:
        df: DataFrame with a 'study_id' column; every other column is treated as a label.
        coordinates_df: DataFrame with 'study_id', 'condition', 'level', 'x_scaled',
            'y_scaled', 'series_description' and 'slice_number' columns.
        series_description_df: stored as an attribute; not used by this class.
        root_dir: directory holding ``cvt_png/<study_id>/<series_description>/*.png``.
        transform: optional callable applied to each grayscale PIL slice. Without it the
            slices stay PIL images, which ``torch.stack`` cannot combine, so in practice a
            transform that returns tensors is required.
        series_descriptions: series to load; defaults to the module-level
            ``SERIES_DESCRIPTIONS``.
    """

    # Shape of the stand-in tensor returned for a series that has no images.
    PLACEHOLDER_SHAPE = (1, 1, 512, 512)

    def __init__(self, df, coordinates_df, series_description_df, root_dir, transform=None,
                 series_descriptions=None):
        self.df = df
        self.coordinates_df = coordinates_df
        self.series_description_df = series_description_df
        self.root_dir = root_dir  # The root directory where images are stored
        self.transform = transform
        # The global is only looked up when no explicit list is passed.
        self.series_descriptions = (
            SERIES_DESCRIPTIONS if series_descriptions is None else series_descriptions
        )

        # Get the list of study_ids
        self.study_ids = self.df['study_id'].unique()

        # List of label columns (all columns except 'study_id')
        self.label_columns = [col for col in df.columns if col != 'study_id']

        # study_id -> series_description -> list of PNG paths
        self.study_image_paths = self._prepare_image_paths()

        # study_id -> list of integer labels
        self.labels_dict = self._prepare_labels()

        # Group the annotation rows once so __getitem__ does a dict lookup instead of
        # scanning the whole coordinates table for every sample.
        self.annotations_by_study = {
            study_id: rows
            for study_id, rows in self.coordinates_df.groupby('study_id', sort=False)
        }

    def _prepare_image_paths(self):
        """Map study_id -> series_description -> sorted PNG paths ([] if the folder is missing).

        Paths are sorted lexicographically, which matches slice order only when the file
        names are zero-padded (e.g. 001.png) — TODO confirm for the converted PNGs.
        """
        study_image_paths = {}
        for study_id in self.study_ids:
            per_series = {}
            for series_description in self.series_descriptions:
                folder_name = series_description.replace('/', '_')
                image_dir = os.path.join(self.root_dir, 'cvt_png', str(study_id), folder_name)
                if os.path.exists(image_dir):
                    per_series[series_description] = sorted(glob.glob(os.path.join(image_dir, '*.png')))
                else:
                    # Missing series: __getitem__ substitutes a placeholder tensor
                    per_series[series_description] = []
            study_image_paths[study_id] = per_series
        return study_image_paths

    def _prepare_labels(self):
        """Map study_id -> list of labels (ints; NaN or -100 become -100, the ignore_index)."""
        labels_dict = {}
        for _, row in self.df.iterrows():
            labels = []
            for col in self.label_columns:
                label = row[col]
                if pd.isnull(label) or label == -100:
                    labels.append(-100)
                else:
                    labels.append(int(label))
            labels_dict[row['study_id']] = labels
        return labels_dict

    def __len__(self):
        return len(self.study_ids)

    def _load_slice(self, img_path):
        """Read one image as grayscale ('L') and apply ``transform`` if one was given."""
        with Image.open(img_path) as img:
            gray = img.convert('L')
        return self.transform(gray) if self.transform else gray

    def _collect_annotations(self, study_id):
        """Nest the study's coordinate rows as {condition_level: {series: [point, ...]}}."""
        annotations = {}
        study_rows = self.annotations_by_study.get(study_id)
        if study_rows is None:
            return annotations
        for _, row in study_rows.iterrows():
            key = f"{row['condition']}_{row['level']}"
            points = annotations.setdefault(key, {}).setdefault(row['series_description'], [])
            points.append({
                'x': row['x_scaled'],
                'y': row['y_scaled'],
                'slice_number': int(row['slice_number']),
            })
        return annotations

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]

        # Load images for each series description
        images = {}
        for series_description in self.series_descriptions:
            image_paths = self.study_image_paths[study_id][series_description]
            series_images = [self._load_slice(img_path) for img_path in image_paths]
            if series_images:
                # Stack slices along a new leading (depth) axis
                images[series_description] = torch.stack(series_images, dim=0)
            else:
                images[series_description] = torch.zeros(self.PLACEHOLDER_SHAPE)

        # Long dtype, as required for CrossEntropyLoss targets
        labels_tensor = torch.tensor(self.labels_dict[study_id], dtype=torch.long)

        return {
            'study_id': study_id,
            'images': images,
            'labels': labels_tensor,
            'annotations': self._collect_annotations(study_id),
        }


In [155]:
# Per-slice preprocessing: resize to IMAGE_SIZE, convert the grayscale PIL image to a
# [1, H, W] float tensor in [0, 1] (ToTensor), then normalise to roughly [-1, 1].
IMAGE_SIZE = (512, 512)  # keep in sync with the dataset's 512x512 placeholder for missing series
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # Adjust mean and std if necessary
])

# Instantiate the dataset; images are read from <root_dir>/cvt_png/<study_id>/<series>/*.png
train_dataset = LumbarSpineDataset(
    df=df,
    coordinates_df=coordinates_df,
    series_description_df=series_description_df,
    root_dir='./rsna_output',  # Adjust the path as needed
    transform=transform,
)

# batch_size must stay 1: studies have different numbers of slices per series, so the
# default collate function cannot stack several studies into one batch, and the training
# loop squeezes away a batch dimension of size 1.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,  # Adjust based on your system
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
    pin_memory=torch.cuda.is_available(),
)


In [156]:
# Define the ResNet feature extractor
class ResNetFeatureExtractor(nn.Module):
    """ResNet-18 trunk (conv1 through layer4; no avgpool/fc) that accepts `in_channels` inputs.

    conv1 is replaced by a new, randomly initialised convolution with `in_channels` input
    channels, so the pretrained first-layer weights are not used; the remaining layers keep
    the weights loaded by `models.resnet18(pretrained=True)`.

    Output: feature map of shape [batch_size, 512, H', W'].
    """

    def __init__(self, in_channels=10):
        super().__init__()
        # NOTE: `pretrained=True` is deprecated in newer torchvision (0.13+) in favour of
        # `weights=...`; kept here for compatibility with older versions.
        resnet = models.resnet18(pretrained=True)
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        )

    def forward(self, x):
        return self.features(x)  # [batch_size, 512, H', W']


# Define the main model
class MultiSeriesSpineModel(nn.Module):
    """Three-branch classifier: one ResNet-18 trunk plus spatial attention per MRI series.

    Each input is [batch_size, in_channels, H, W] (slices stacked as channels). Per branch:
    features -> sigmoid attention map -> attention-weighted features -> global average pool
    -> 512-d vector. The three vectors are concatenated and passed through fc1 (ReLU) and
    fc2; the logits are returned as [batch_size, num_conditions, num_classes].

    Args:
        num_conditions: number of classification targets.
        num_classes: classes per target.
        in_channels: slices per series expected by each feature extractor.
    """

    def __init__(self, num_conditions=25, num_classes=3, in_channels=10):
        super().__init__()
        # Stored so forward() reshapes with this model's own configuration rather than
        # whatever `num_conditions` / `num_classes` exist as notebook globals.
        self.num_conditions = num_conditions
        self.num_classes = num_classes

        # Keep this registration order (extractors, attention, fc): it fixes parameter
        # order for the optimizer and the order in which seeded layers are initialised.
        self.cnn_sagittal_t1 = ResNetFeatureExtractor(in_channels=in_channels)
        self.cnn_sagittal_t2_stir = ResNetFeatureExtractor(in_channels=in_channels)
        self.cnn_axial_t2 = ResNetFeatureExtractor(in_channels=in_channels)

        self.attention_sagittal_t1 = self._make_attention()
        self.attention_sagittal_t2_stir = self._make_attention()
        self.attention_axial_t2 = self._make_attention()

        combined_feature_size = 512 * 3  # concatenated pooled features of the three branches
        self.fc1 = nn.Linear(combined_feature_size, 512)
        self.fc2 = nn.Linear(512, num_conditions * num_classes)

    @staticmethod
    def _make_attention(channels=512):
        """1x1 conv + sigmoid producing a [batch_size, 1, H', W'] attention map in (0, 1)."""
        return nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _attend_and_pool(features, attention):
        """Weight `features` by `attention(features)` and global-average-pool to [batch_size, C]."""
        weighted = features * attention(features)
        return torch.flatten(F.adaptive_avg_pool2d(weighted, (1, 1)), 1)

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        pooled = [
            self._attend_and_pool(self.cnn_sagittal_t1(sagittal_t1), self.attention_sagittal_t1),
            self._attend_and_pool(self.cnn_sagittal_t2_stir(sagittal_t2_stir), self.attention_sagittal_t2_stir),
            self._attend_and_pool(self.cnn_axial_t2(axial_t2), self.attention_axial_t2),
        ]
        combined_features = torch.cat(pooled, dim=1)  # [batch_size, 512 * 3]

        x = F.relu(self.fc1(combined_features))
        x = self.fc2(x)  # [batch_size, num_conditions * num_classes]
        return x.view(x.size(0), self.num_conditions, self.num_classes)  # logits


In [157]:
# Resample slices function
def resample_slices(image_tensor, target_slices=10):
    """
    Resample a stack of slices along its first axis to exactly `target_slices` slices.

    Args:
        image_tensor: tensor of shape [num_slices, channels, H, W]. Must be floating point
            when upsampling, since trilinear interpolation is used.
        target_slices: desired number of slices (>= 1).

    Returns:
        `image_tensor` itself if it already has `target_slices` slices; otherwise a tensor
        of shape [target_slices, channels, H, W].
        - Downsampling selects the slices at indices
          torch.linspace(0, num_slices - 1, target_slices), truncated to integers.
        - Upsampling interpolates trilinearly along the slice axis; H and W are unchanged.

    Raises:
        ValueError: if `target_slices` < 1 or `image_tensor` contains no slices.
    """
    if target_slices < 1:
        raise ValueError(f"target_slices must be >= 1, got {target_slices}")

    current_slices = image_tensor.shape[0]
    if current_slices == 0:
        raise ValueError("cannot resample an empty slice stack")

    if current_slices == target_slices:
        return image_tensor  # No need to resample

    if current_slices > target_slices:
        # Pick evenly spaced slices
        indices = torch.linspace(0, current_slices - 1, target_slices).long()
        return image_tensor[indices]

    # Fewer slices: treat the stack as a volume and interpolate along depth only
    volume = image_tensor.permute(1, 0, 2, 3).unsqueeze(0)  # [1, channels, slices, H, W]
    resized = F.interpolate(
        volume,
        size=(target_slices, volume.shape[3], volume.shape[4]),
        mode='trilinear',
        align_corners=False,
    )
    return resized.squeeze(0).permute(1, 0, 2, 3)  # [target_slices, channels, H, W]


In [160]:
# Seed the global torch RNG (all devices). This fixes the random initialisation of the
# new layers below and, because the DataLoader draws its shuffle seed from the same RNG
# when iteration starts, the shuffle order of a full Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Instantiate the model
num_conditions = len(train_dataset.label_columns)
num_classes = len(label2id)  # Normal/Mild, Moderate, Severe
model = MultiSeriesSpineModel(num_conditions=num_conditions, num_classes=num_classes)

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=-100)  # -100 marks missing labels
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)




In [None]:
# Training loop
num_epochs = 10  # Number of passes over the training set
# Slices per series fed to the model; must equal the feature extractors' in_channels (10 by default).
TARGET_SLICES = 10


def prepare_series(series_batch, target_slices=TARGET_SLICES):
    """Turn one collated series into a model input of shape [1, target_slices, H, W] on `device`.

    Assumes the DataLoader uses batch_size=1, so `series_batch` is [1, num_slices, 1, H, W];
    the slices are resampled to `target_slices` and then used as the channel axis.
    """
    slices = series_batch.squeeze(0)                               # [num_slices, 1, H, W]
    slices = resample_slices(slices, target_slices=target_slices)  # [target_slices, 1, H, W]
    return slices.squeeze(1).unsqueeze(0).to(device)               # [1, target_slices, H, W]


model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for batch in progress_bar:
        images = batch['images']
        sagittal_t1 = prepare_series(images['Sagittal T1'])
        sagittal_t2_stir = prepare_series(images['Sagittal T2_STIR'])
        axial_t2 = prepare_series(images['Axial T2'])

        # Collated labels are [1, num_conditions]; flatten to one target per condition
        labels_tensor = batch['labels'].view(-1).to(device)

        # If every label of this study is missing, mean-reduced CrossEntropyLoss has no
        # valid targets and returns NaN, which backward() would spread into the weights.
        if not (labels_tensor != criterion.ignore_index).any():
            continue

        optimizer.zero_grad()
        outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)  # [1, num_conditions, num_classes]
        total_loss = criterion(outputs.view(-1, num_classes), labels_tensor)
        total_loss.backward()
        optimizer.step()

        loss_value = total_loss.item()
        epoch_loss += loss_value
        num_batches += 1
        progress_bar.set_postfix({'Loss': f'{loss_value:.4f}'})

    # Average over the batches that actually contributed a loss
    avg_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")


Epoch 1/10:  11%|█         | 222/1975 [00:20<02:21, 12.38batch/s, Loss=0.6622]


### Objective:
We want to plot an annotated axial slice from our dataset. The annotations come from `coordinates_df`, which contains x, y coordinates together with identifying information such as the `study_id`, `series_id`, `instance_number`, `condition`, and `level`. Each annotation marks the slice (and the point on it) associated with a condition/level whose severity we are trying to classify.

### Key Points:
1. **Data Sources**:
   - **`df`**: This contains the labels for `condition` and `level` across different spinal areas for each `study_id`.
   - **`coordinates_df`**: This contains the x, y coordinates, `series_id`, `instance_number`, `condition`, and `level` related to each `study_id`.
   - **`series_description_df`**: This maps the `series_id` to its respective `series_description` (e.g., 'Axial T2', 'Sagittal T1').

2. **Image Path Mapping**:
   - From `coordinates_df`, we need to extract the `study_id`, `series_id`, and `instance_number` to locate the corresponding axial image. 
   - The image path is generated using:
     ```python
     image_path = f'./rsna_output/cvt_png/{study_id}/{series_description}/{instance_number:03d}.png'
     ```
     where `series_description` is derived from the `series_id` using the `series_description_df`.

3. **DataLoader Responsibilities**:
   - The DataLoader needs to provide the required information (`study_id`, `series_id`, `instance_number`, `x`, `y`) to correctly map images and annotations.
   - Slices without annotations should still be provided, and treated as explicit 'no annotation' examples for the model.

### Process Flow:
1. **Fetch Image and Annotations**:
   - For each study (`study_id`), find the `x`, `y` coordinates from `coordinates_df`.
   - Get the corresponding `series_id` and map it to a `series_description` using `series_description_df`.
   - Locate the slice image using `series_description` and `instance_number`.

2. **Plotting**:
   - Display the axial slice image with a bounding box drawn around the `x`, `y` coordinates for the annotation.
   - Display the label for the corresponding `condition` and `level`.



In [None]:
# Optional: render the model's computational graph with torchviz.
# torchviz is not installed by the setup cells above; install it before uncommenting.
# from torchviz import make_dot
# from PIL import Image
# import matplotlib.pyplot as plt

# # Create dummy data to simulate model input: [batch_size, 10 slices, H, W] per series
# batch_size = 2
# dummy_sagittal_t1 = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Sagittal T1
# dummy_sagittal_t2_stir = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Sagittal T2/STIR
# dummy_axial_t2 = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Axial T2

# # Pass through the model to get a forward pass.
# # MultiSeriesSpineModel returns a single logits tensor [batch_size, num_conditions, num_classes],
# # so it must not be unpacked into two outputs.
# logits = model(dummy_sagittal_t1.to(device), dummy_sagittal_t2_stir.to(device), dummy_axial_t2.to(device))

# # Create the computational graph
# dot = make_dot(logits, params=dict(model.named_parameters()))

# # Render to a file and display it
# dot.render("model_diagram", format="png")  # Save as PNG

# # Load and display the image
# img = Image.open("model_diagram.png")
# plt.figure(figsize=(10, 10))  # Increase the figure size for better clarity
# plt.imshow(img)
# plt.axis('off')  # Hide axes for clarity
# plt.show()