In [20]:
from google.colab import drive

drive.mount('/content/drive/')

%cd "/content/drive/MyDrive/CONSEGNA_ML/"


Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/.shortcut-targets-by-id/1vnO-rlfCpoeqYPLaBVG66OUTYbxtehIA/CONSEGNA_ML


#IMPLEMENT HERE YOUR CODE

In [21]:
from torchvision import transforms
import torch

##IMPLEMENT HERE THE FUNCTION TO LOAD YOUR MODEL
For example, here we use a single convolutional layer.

In [22]:
from models import DeepLabV3, CNN_7_Layers, DeepLabV3Lite

def load_model():
  model_instance = DeepLabV3Lite()
  # checkpoint = torch.load('checkpoint_epoch_50.pth')
  checkpoint = torch.load('checkpoint_epoch_15.pth', map_location=torch.device('cpu'))
  model_instance.load_state_dict(checkpoint['model_state_dict'])
  model_instance.eval()
  return model_instance

##IMPLEMENT HERE YOUR PREDICT FUNCTION

In [23]:
from utils import convert_onehotencoding_to_rgb

def predict(model, X):
    # accept (3, H, W) or (1, 3, H, W), values in [0, 1]
    # returns an image (3,H,W) that went through the model (with pixel segmentation)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # ensure dimension of the model (1, 3, H, W)
    transform = transforms.Compose([
        transforms.ToTensor(),  # Converts to (3, H, W), values in [0,1]
    ])

    image_tensor = transform(X).unsqueeze(0).to(device)  # (1, 3, H, W)

    with torch.no_grad():
        output = model(image_tensor)           # (1, 9, H, W)
        predicted = output.argmax(dim=1)       # (1, H, W)
        predicted = predicted.squeeze(0).cpu().numpy()  # (H, W) - class indices

    predicted_rgb = convert_onehotencoding_to_rgb(predicted)

    return predicted_rgb

#DO NOT MODIFY THE CODE BELOW!

This is exactly the code we run for the final test.

After implementing the previous functions, run this code to verify that it works

In [24]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Metrics
def compute_iou(mask1, mask2, label):
  intersection = np.sum((mask1 == label) & (mask2 == label))
  union = np.sum((mask1 == label) | (mask2 == label))
  if union == 0:
    return np.nan
  return intersection / union
def compute_all_iou(mask1, mask2, num_labels=8):
  iou_scores = np.zeros((num_labels))
  for label in range(num_labels):
    iou = compute_iou(mask1, mask2, label+1) # we skip the background label
    iou_scores[label] = iou
  return iou_scores


# Run YOUR LOAD_MODEL FUNCTION
model = load_model()

# Main loop
test_dir = "./train"
# test_dir = '../../datasets/esame_ml/train' # we will change this path with that of the private test set directory
samples = os.listdir(test_dir)
IOUs = np.zeros((len(samples), 8))
verbose = False

for i, subdir in tqdm(enumerate(samples), desc="Processing samples"):
    subdir_path = os.path.join(test_dir, subdir)

    if os.path.isdir(subdir_path):
        # Get the data paths
        rgb_path = os.path.join(subdir_path, 'rgb.jpg')
        labels_path = os.path.join(subdir_path, 'labels.png')

        if os.path.exists(rgb_path) and os.path.exists(labels_path):
            if verbose:
                print(f"Processing subdirectory: {subdir}")

            try:  # ATTENTION: any error occurring in this try-catch means that the corresponding IOUs are evaluated as ZERO

                # Open images
                rgb_image = Image.open(rgb_path)
                rgb_array = np.asarray(rgb_image).copy()
                labels_image = Image.open(labels_path).copy()
                labels_array = np.asarray(labels_image)
                if verbose:
                    print(f"  Loaded {rgb_path} and {labels_path}")

                # Run YOUR PREDICT FUNCTION
                predicted_labels_array = predict(model, rgb_array)

                # Evaluate the IOU metric
                IOUs[i,:] = compute_all_iou(labels_array, predicted_labels_array)

                if verbose:
                    labels_vals = np.unique(np.asarray(labels_image))
                    print(f"  Unique labels values: {labels_vals}")
                    predicted_labels_vals = np.unique(np.asarray(predicted_labels_array))
                    print(f"  Unique predicted labels values: {predicted_labels_vals}")

                    plt.subplot(1, 3, 1)
                    plt.imshow(rgb_image)
                    plt.subplot(1, 3, 2)
                    plt.imshow(labels_image)
                    plt.subplot(1, 3, 3)
                    plt.imshow(predicted_labels_array)
                    plt.show()

                rgb_image.close()
                labels_image.close()

            except FileNotFoundError:
                print(f"  Error: Could not find image files in {subdir_path}")
            except Exception as e:
                print(f"  Error processing images in {subdir_path}: {e}")
        else:
            print(f"  Skipping subdirectory {subdir}: rgb.jpg or labels.png not found.")

score = np.nanmean(IOUs)
print(f"\nFinal competition score: {score}")

Processing samples: 2it [00:01,  1.55it/s]

  Error processing images in ./train/0287: Expected shape (9, M, N)


Processing samples: 3it [00:02,  1.27it/s]

  Error processing images in ./train/1171: Expected shape (9, M, N)


Processing samples: 4it [00:02,  1.32it/s]

  Error processing images in ./train/0099: Expected shape (9, M, N)


Processing samples: 5it [00:03,  1.30it/s]

  Error processing images in ./train/0418: Expected shape (9, M, N)


Processing samples: 6it [00:04,  1.29it/s]

  Error processing images in ./train/1172: Expected shape (9, M, N)


Processing samples: 7it [00:05,  1.32it/s]

  Error processing images in ./train/0506: Expected shape (9, M, N)


Processing samples: 8it [00:06,  1.34it/s]

  Error processing images in ./train/0195: Expected shape (9, M, N)


Processing samples: 9it [00:06,  1.33it/s]

  Error processing images in ./train/1223: Expected shape (9, M, N)


Processing samples: 10it [00:07,  1.31it/s]

  Error processing images in ./train/1219: Expected shape (9, M, N)


Processing samples: 11it [00:08,  1.31it/s]

  Error processing images in ./train/1261: Expected shape (9, M, N)


Processing samples: 12it [00:09,  1.32it/s]

  Error processing images in ./train/1170: Expected shape (9, M, N)


Processing samples: 13it [00:09,  1.32it/s]

  Error processing images in ./train/0711: Expected shape (9, M, N)


Processing samples: 14it [00:10,  1.16it/s]

  Error processing images in ./train/0180: Expected shape (9, M, N)


Processing samples: 15it [00:11,  1.08it/s]

  Error processing images in ./train/1259: Expected shape (9, M, N)


Processing samples: 16it [00:13,  1.17s/it]

  Error processing images in ./train/1065: Expected shape (9, M, N)


Processing samples: 17it [00:14,  1.19s/it]

  Error processing images in ./train/0161: Expected shape (9, M, N)


Processing samples: 18it [00:16,  1.20s/it]

  Error processing images in ./train/1333: Expected shape (9, M, N)


Processing samples: 19it [00:17,  1.35s/it]

  Error processing images in ./train/1258: Expected shape (9, M, N)


Processing samples: 20it [00:19,  1.44s/it]

  Error processing images in ./train/0544: Expected shape (9, M, N)


Processing samples: 21it [00:21,  1.48s/it]

  Error processing images in ./train/1325: Expected shape (9, M, N)


Processing samples: 22it [00:22,  1.59s/it]

  Error processing images in ./train/1343: Expected shape (9, M, N)


Processing samples: 23it [00:25,  1.73s/it]

  Error processing images in ./train/0268: Expected shape (9, M, N)


Processing samples: 24it [00:26,  1.75s/it]

  Error processing images in ./train/1159: Expected shape (9, M, N)


Processing samples: 25it [00:28,  1.67s/it]

  Error processing images in ./train/0641: Expected shape (9, M, N)


Processing samples: 26it [00:29,  1.59s/it]

  Error processing images in ./train/1176: Expected shape (9, M, N)


Processing samples: 27it [00:31,  1.52s/it]

  Error processing images in ./train/1215: Expected shape (9, M, N)


Processing samples: 28it [00:32,  1.44s/it]

  Error processing images in ./train/1308: Expected shape (9, M, N)


Processing samples: 29it [00:33,  1.34s/it]

  Error processing images in ./train/0092: Expected shape (9, M, N)


Processing samples: 30it [00:34,  1.18s/it]

  Error processing images in ./train/0291: Expected shape (9, M, N)


Processing samples: 31it [00:35,  1.08s/it]

  Error processing images in ./train/0309: Expected shape (9, M, N)


Processing samples: 32it [00:35,  1.00it/s]

  Error processing images in ./train/1226: Expected shape (9, M, N)


Processing samples: 33it [00:36,  1.07it/s]

  Error processing images in ./train/1240: Expected shape (9, M, N)


Processing samples: 34it [00:37,  1.13it/s]

  Error processing images in ./train/1328: Expected shape (9, M, N)


Processing samples: 35it [00:38,  1.15it/s]

  Error processing images in ./train/0212: Expected shape (9, M, N)


Processing samples: 36it [00:39,  1.17it/s]

  Error processing images in ./train/0152: Expected shape (9, M, N)


Processing samples: 37it [00:39,  1.20it/s]

  Error processing images in ./train/0004: Expected shape (9, M, N)


Processing samples: 38it [00:40,  1.22it/s]

  Error processing images in ./train/0638: Expected shape (9, M, N)


Processing samples: 39it [00:41,  1.23it/s]

  Error processing images in ./train/1059: Expected shape (9, M, N)


Processing samples: 40it [00:42,  1.23it/s]

  Error processing images in ./train/0211: Expected shape (9, M, N)


Processing samples: 41it [00:44,  1.15s/it]

  Error processing images in ./train/0199: Expected shape (9, M, N)


Processing samples: 42it [00:46,  1.38s/it]

  Error processing images in ./train/0247: Expected shape (9, M, N)


Processing samples: 43it [00:47,  1.36s/it]

  Error processing images in ./train/0517: Expected shape (9, M, N)


Processing samples: 44it [00:48,  1.27s/it]

  Error processing images in ./train/0359: Expected shape (9, M, N)


Processing samples: 45it [00:49,  1.13s/it]

  Error processing images in ./train/0446: Expected shape (9, M, N)


Processing samples: 46it [00:50,  1.02s/it]

  Error processing images in ./train/1295: Expected shape (9, M, N)


Processing samples: 47it [00:50,  1.04it/s]

  Error processing images in ./train/0683: Expected shape (9, M, N)


Processing samples: 48it [00:51,  1.10it/s]

  Error processing images in ./train/0699: Expected shape (9, M, N)


Processing samples: 49it [00:52,  1.15it/s]

  Error processing images in ./train/0192: Expected shape (9, M, N)


Processing samples: 50it [00:53,  1.20it/s]

  Error processing images in ./train/1060: Expected shape (9, M, N)


Processing samples: 51it [00:54,  1.21it/s]

  Error processing images in ./train/0053: Expected shape (9, M, N)


Processing samples: 52it [00:54,  1.23it/s]

  Error processing images in ./train/0757: Expected shape (9, M, N)


Processing samples: 53it [00:55,  1.23it/s]

  Error processing images in ./train/1118: Expected shape (9, M, N)


Processing samples: 54it [00:56,  1.24it/s]

  Error processing images in ./train/0676: Expected shape (9, M, N)


Processing samples: 55it [00:57,  1.25it/s]

  Error processing images in ./train/0834: Expected shape (9, M, N)


Processing samples: 56it [00:57,  1.25it/s]

  Error processing images in ./train/1102: Expected shape (9, M, N)


Processing samples: 57it [00:59,  1.09it/s]

  Error processing images in ./train/1312: Expected shape (9, M, N)


Processing samples: 58it [01:00,  1.02s/it]

  Error processing images in ./train/0320: Expected shape (9, M, N)


Processing samples: 59it [01:01,  1.10s/it]

  Error processing images in ./train/1296: Expected shape (9, M, N)


Processing samples: 60it [01:02,  1.15s/it]

  Error processing images in ./train/0001: Expected shape (9, M, N)


Processing samples: 61it [01:03,  1.07s/it]

  Error processing images in ./train/0398: Expected shape (9, M, N)


Processing samples: 62it [01:04,  1.02it/s]

  Error processing images in ./train/0007: Expected shape (9, M, N)


Processing samples: 63it [01:05,  1.07it/s]

  Error processing images in ./train/0669: Expected shape (9, M, N)


Processing samples: 64it [01:06,  1.12it/s]

  Error processing images in ./train/0391: Expected shape (9, M, N)


Processing samples: 65it [01:07,  1.16it/s]

  Error processing images in ./train/1321: Expected shape (9, M, N)


Processing samples: 66it [01:07,  1.18it/s]

  Error processing images in ./train/0282: Expected shape (9, M, N)


Processing samples: 67it [01:08,  1.19it/s]

  Error processing images in ./train/0057: Expected shape (9, M, N)


Processing samples: 68it [01:09,  1.21it/s]

  Error processing images in ./train/0766: Expected shape (9, M, N)


Processing samples: 69it [01:10,  1.23it/s]

  Error processing images in ./train/0658: Expected shape (9, M, N)


Processing samples: 70it [01:11,  1.24it/s]

  Error processing images in ./train/1342: Expected shape (9, M, N)


Processing samples: 71it [01:11,  1.25it/s]

  Error processing images in ./train/0021: Expected shape (9, M, N)


Processing samples: 72it [01:12,  1.27it/s]

  Error processing images in ./train/0333: Expected shape (9, M, N)


Processing samples: 73it [01:13,  1.21it/s]

  Error processing images in ./train/0087: Expected shape (9, M, N)


Processing samples: 74it [01:14,  1.05it/s]

  Error processing images in ./train/1140: Expected shape (9, M, N)


Processing samples: 75it [01:16,  1.06s/it]

  Error processing images in ./train/1201: Expected shape (9, M, N)


Processing samples: 76it [01:17,  1.13s/it]

  Error processing images in ./train/0436: Expected shape (9, M, N)


Processing samples: 77it [01:18,  1.12s/it]

  Error processing images in ./train/1081: Expected shape (9, M, N)


Processing samples: 78it [01:19,  1.01s/it]

  Error processing images in ./train/0721: Expected shape (9, M, N)


Processing samples: 79it [01:19,  1.06it/s]

  Error processing images in ./train/0729: Expected shape (9, M, N)


Processing samples: 80it [01:20,  1.12it/s]

  Error processing images in ./train/0710: Expected shape (9, M, N)


Processing samples: 81it [01:21,  1.16it/s]

  Error processing images in ./train/1227: Expected shape (9, M, N)


Processing samples: 82it [01:22,  1.19it/s]

  Error processing images in ./train/0011: Expected shape (9, M, N)


Processing samples: 83it [01:23,  1.22it/s]

  Error processing images in ./train/0513: Expected shape (9, M, N)


Processing samples: 84it [01:23,  1.23it/s]

  Error processing images in ./train/1185: Expected shape (9, M, N)


Processing samples: 85it [01:24,  1.22it/s]

  Error processing images in ./train/1297: Expected shape (9, M, N)


Processing samples: 86it [01:25,  1.24it/s]

  Error processing images in ./train/0240: Expected shape (9, M, N)


Processing samples: 87it [01:26,  1.25it/s]

  Error processing images in ./train/0061: Expected shape (9, M, N)


Processing samples: 88it [01:27,  1.26it/s]

  Error processing images in ./train/0379: Expected shape (9, M, N)


Processing samples: 89it [01:27,  1.25it/s]

  Error processing images in ./train/1192: Expected shape (9, M, N)


Processing samples: 90it [01:29,  1.11it/s]

  Error processing images in ./train/0125: Expected shape (9, M, N)


Processing samples: 91it [01:30,  1.03s/it]

  Error processing images in ./train/1337: Expected shape (9, M, N)


Processing samples: 92it [01:31,  1.11s/it]

  Error processing images in ./train/0019: Expected shape (9, M, N)


Processing samples: 93it [01:32,  1.14s/it]

  Error processing images in ./train/1220: Expected shape (9, M, N)


Processing samples: 94it [01:33,  1.06s/it]

  Error processing images in ./train/0269: Expected shape (9, M, N)


Processing samples: 95it [01:34,  1.03it/s]

  Error processing images in ./train/0169: Expected shape (9, M, N)


Processing samples: 96it [01:35,  1.09it/s]

  Error processing images in ./train/0691: Expected shape (9, M, N)


Processing samples: 97it [01:36,  1.15it/s]

  Error processing images in ./train/1329: Expected shape (9, M, N)


Processing samples: 98it [01:36,  1.18it/s]

  Error processing images in ./train/0312: Expected shape (9, M, N)


Processing samples: 99it [01:37,  1.20it/s]

  Error processing images in ./train/0428: Expected shape (9, M, N)


Processing samples: 100it [01:38,  1.21it/s]

  Error processing images in ./train/0293: Expected shape (9, M, N)


Processing samples: 101it [01:39,  1.24it/s]

  Error processing images in ./train/0809: Expected shape (9, M, N)


Processing samples: 102it [01:40,  1.25it/s]

  Error processing images in ./train/0279: Expected shape (9, M, N)


Processing samples: 103it [01:40,  1.25it/s]

  Error processing images in ./train/0545: Expected shape (9, M, N)


Processing samples: 104it [01:41,  1.24it/s]

  Error processing images in ./train/1161: Expected shape (9, M, N)


Processing samples: 105it [01:42,  1.24it/s]

  Error processing images in ./train/0175: Expected shape (9, M, N)


Processing samples: 106it [01:43,  1.21it/s]

  Error processing images in ./train/1289: Expected shape (9, M, N)


Processing samples: 107it [01:44,  1.03s/it]

  Error processing images in ./train/0689: Expected shape (9, M, N)


Processing samples: 108it [01:46,  1.27s/it]

  Error processing images in ./train/1166: Expected shape (9, M, N)


Processing samples: 109it [01:48,  1.40s/it]

  Error processing images in ./train/0396: Expected shape (9, M, N)


Processing samples: 110it [01:49,  1.29s/it]

  Error processing images in ./train/1106: Expected shape (9, M, N)


Processing samples: 111it [01:50,  1.14s/it]

  Error processing images in ./train/0696: Expected shape (9, M, N)


Processing samples: 112it [01:51,  1.04s/it]

  Error processing images in ./train/0059: Expected shape (9, M, N)


Processing samples: 113it [01:52,  1.09s/it]

  Error processing images in ./train/1184: Expected shape (9, M, N)


Processing samples: 114it [01:53,  1.18s/it]

  Error processing images in ./train/1076: Expected shape (9, M, N)


Processing samples: 115it [01:54,  1.23s/it]

  Error processing images in ./train/0394: Expected shape (9, M, N)


Processing samples: 116it [01:56,  1.25s/it]

  Error processing images in ./train/0362: Expected shape (9, M, N)


Processing samples: 117it [01:57,  1.11s/it]

  Error processing images in ./train/0397: Expected shape (9, M, N)


Processing samples: 118it [01:57,  1.02s/it]

  Error processing images in ./train/1224: Expected shape (9, M, N)


Processing samples: 119it [01:58,  1.03it/s]

  Error processing images in ./train/0158: Expected shape (9, M, N)


Processing samples: 120it [01:59,  1.00s/it]

  Error processing images in ./train/1272: Expected shape (9, M, N)


Processing samples: 121it [02:01,  1.09s/it]

  Error processing images in ./train/1202: Expected shape (9, M, N)


Processing samples: 122it [02:02,  1.17s/it]

  Error processing images in ./train/0678: Expected shape (9, M, N)


Processing samples: 123it [02:03,  1.19s/it]

  Error processing images in ./train/0051: Expected shape (9, M, N)


Processing samples: 124it [02:04,  1.13s/it]

  Error processing images in ./train/1269: Expected shape (9, M, N)


Processing samples: 125it [02:05,  1.02s/it]

  Error processing images in ./train/0198: Expected shape (9, M, N)


Processing samples: 125it [02:05,  1.01s/it]


KeyboardInterrupt: 

More information useful for

In [None]:
import numpy as np
np.set_printoptions(precision=3, suppress=True)
print(f"All IOUs:\n{IOUs}")
print("Average IOUs for each:")
print(f"- class: {np.nanmean(IOUs, 0)}")
print(f"- image: {np.nanmean(IOUs, 1)}")