# Setup Environment

In [None]:
import os

# Root of the CARLA self-driving workspace; raises KeyError if the variable is unset.
workspace_root = os.environ["SELF_DRIVE_CARLA_WORKSPACE"]
# All artefacts of this notebook (CSVs, checkpoints) live under this sub-directory.
project_workspace = os.path.join(
    workspace_root,
    "01-semantic-segmentation",
)

In [None]:
import torch
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp

# Import your SegmentationDataset class
from src.segmentation_dataset import SegmentationDataset

# Perform train/validation/test split

In [None]:
import pandas as pd

full_dataset_path= os.path.join(project_workspace, "reduced_dataset.csv")
df = pd.read_csv(full_dataset_path, nrows=None)

### Train

In [None]:
# Positions 0-6 of every consecutive group of 10 rows (70%) form the training split.
# NOTE(review): rows are assigned by position, so if the CSV lists consecutive frames,
# near-duplicate neighbouring frames end up in different splits — consider a sequence-level split.
train_mask = [row_pos % 10 in range(7) for row_pos in range(df.shape[0])]
train_df = df[train_mask]
train_csv_path = os.path.join(project_workspace, "reduced_train_dataset.csv")
train_df.to_csv(train_csv_path, index=False)

### Validation

In [None]:
# Positions 7 and 8 of every consecutive group of 10 rows (20%) form the validation split.
val_mask = [row_pos % 10 in (7, 8) for row_pos in range(df.shape[0])]
val_df = df[val_mask]
val_csv_path = os.path.join(project_workspace, "reduced_val_dataset.csv")
val_df.to_csv(val_csv_path, index=False)

### Test

In [None]:
# The last position of every consecutive group of 10 rows (10%) forms the test split.
test_mask = [(row_pos + 1) % 10 == 0 for row_pos in range(df.shape[0])]
test_df = df[test_mask]
test_csv_path = os.path.join(project_workspace, "reduced_test_dataset.csv")
test_df.to_csv(test_csv_path, index=False)

# Train

In [None]:
# DeepLabV3+ with an ImageNet-pretrained ResNet-34 encoder and a 15-class output head,
# moved to the GPU right after construction.
model = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=15,
).cuda()

# Pixel-wise multi-class cross-entropy; target pixels labelled 0 contribute no loss.
# NOTE(review): confirm class 0 is a void/unlabelled id — otherwise that class is never learned.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)

In [None]:
from torchvision import transforms

# Standard ImageNet per-channel statistics, consistent with encoder_weights="imagenet".
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# PIL image -> float tensor -> per-channel normalised tensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

In [None]:
train_csv = os.path.join(project_workspace, "reduced_train_dataset.csv")
val_csv = os.path.join(project_workspace, "reduced_val_dataset.csv")

train_dataset = SegmentationDataset(
    train_csv, project_workspace, mode="train", transform=transform
)
# NOTE(review): no `mode` is passed, so SegmentationDataset's default applies — confirm it is
# the intended non-training mode.
val_dataset = SegmentationDataset(val_csv, project_workspace, transform=transform)

# Shuffle only the training data; validation order stays deterministic.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 5

for epoch in range(epochs):
    model.train()
    # Accumulate per-batch losses so the report reflects the whole epoch,
    # not just the (noisy) final batch.
    running_loss = 0.0
    num_batches = 0
    for images, labels in train_loader:
        images = images.cuda()
        labels = labels.cuda()

        predictions = model(images)
        loss = loss_fn(predictions, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, so no autograd graph is retained.
        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Without this guard an empty loader surfaced as a confusing NameError on `loss`.
        raise RuntimeError("train_loader yielded no batches; check the training CSV")

    mean_loss = running_loss / num_batches
    print(f"Epoch {epoch+1}/{epochs}, Mean train loss: {mean_loss:.4f}")

In [None]:
# Persist only the learned weights (state_dict), not the pickled module object.
model_path = os.path.join(project_workspace, "model.pth")
torch.save(model.state_dict(), model_path)

In [None]:
# Switch to inference behaviour (equivalent to model.eval()); returns the model itself.
model.train(False)

In [None]:
# Arbitrary validation sample used as a visual sanity check of the trained model.
sample_index = 3
# Take the first element of the sample, add a leading batch dimension, move it to the GPU.
test_image = val_dataset[sample_index][0].unsqueeze(dim=0).cuda()

with torch.no_grad():
    test_output = model(test_image)

# Drop size-1 dims, then pick the highest-scoring class along dim 0 for each pixel.
logits = test_output.squeeze()
predicted_mask = logits.argmax(dim=0)

In [None]:
import matplotlib.pyplot as plt
from src.labels import trainId2label
import numpy as np

# Convert the predicted class-id mask into an RGB image using each label's colour.
# Copy the mask to host memory once: indexing a CUDA tensor pixel by pixel and calling
# .item() forces a device sync per pixel.
mask_np = predicted_mask.cpu().numpy()

# Size the canvas from the mask itself instead of hard-coding 480x640, so other
# output resolutions neither raise IndexError nor leave unpainted borders.
predicted_mask_rgb = np.zeros((*mask_np.shape, 3))
# Paint all pixels of one class at a time (one lookup per class present, not per pixel).
for class_id in np.unique(mask_np):
    predicted_mask_rgb[mask_np == class_id] = trainId2label[class_id.item()].color

# Scale 0-255 colour values to the [0, 1] float range expected by imshow.
predicted_mask_rgb = predicted_mask_rgb / 255.0

# Predicted mask (RGB)
plt.imshow(predicted_mask_rgb)
plt.axis("off")
plt.show()