# Download the LoveDA Train Set

In [None]:
from google.colab import drive
import os

SAVE_ON_DRIVE = False
drive.mount('/content/drive')
dataset_path = '/content/drive/MyDrive/LoveDA/Train'
if (os.path.exists("./Train") == False):
    if (os.path.exists("/content/drive/MyDrive/LoveDA/Train.zip")):
        print("Dataset available on own drive, unzipping...")
        !unzip -q /content/drive/MyDrive/LoveDA/Train.zip -d ./
    else:
        print("Downloading dataset...")
        !wget -O Train.zip "https://zenodo.org/records/5706578/files/Train.zip?download=1"
        if(SAVE_ON_DRIVE):
            print("Saving dataset on drive...")
            !cp Train.zip /content/drive/MyDrive/LoveDA/
        !unzip -q Train.zip -d ./

else:
    print("Dataset already in local")

Mounted at /content/drive
Dataset available on own drive, unzipping...


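# MixMask Utils

Helpers that pick a random half of the classes present in a label mask and build a binary mask covering exactly those classes.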

In [None]:
def extract_classes_from_mask(mask):
    """Return the distinct label values that occur in ``mask``.

    Args:
        mask: label tensor of any shape.

    Returns:
        1-D tensor of the unique values of ``mask``, sorted ascending.
    """
    distinct = mask.unique(sorted=True)
    return distinct

def select_classes_for_mix(classes):
    """Randomly pick half of ``classes`` (rounded up), without repetition.

    Args:
        classes: 1-D tensor of candidate class ids (e.g. the output of
            ``extract_classes_from_mask``).

    Returns:
        1-D tensor of ``ceil(len(classes) / 2)`` entries taken from distinct
        positions of ``classes``, in random order. Empty input gives an empty tensor.
    """
    nclasses = classes.shape[0]
    num_classes_to_select = (nclasses + 1) // 2  # half of the classes, rounded up if odd
    # randperm samples WITHOUT replacement; torch.randint could draw the same class
    # several times and silently mix fewer than half of the classes.
    perm = torch.randperm(nclasses, device=classes.device)
    return classes[perm[:num_classes_to_select]]

def generate_class_mask(labels, selected_classes):
    """Build a binary mask of the pixels whose label is one of ``selected_classes``.

    Args:
        labels: label tensor (e.g. a segmentation mask).
        selected_classes: iterable of class ids to keep.

    Returns:
        Tensor with the shape, dtype and device of ``labels``: 1 where the label
        is in ``selected_classes``, 0 everywhere else.
    """
    hit = torch.zeros_like(labels, dtype=torch.bool)
    for cls_id in selected_classes:
        hit |= labels == cls_id  # accumulate membership over the chosen classes
    return hit.to(labels.dtype)

# Extract Masks

### Image Loader and Output Directory

In [None]:
from PIL import Image
import os
import torch
from tqdm import tqdm

def pil_loader(path, color_type):
    """Load the image at ``path`` and convert it to the PIL mode ``color_type``.

    The file is opened explicitly and closed on exit to avoid a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835). ``convert`` runs while
    the file is still open, so the returned image is independent of the handle.
    """
    with open(path, 'rb') as fh:
        converted = Image.open(fh).convert(color_type)
    return converted

# Source masks for the Urban training split, and the folder the MixMasks are written to.
masks_path = "./Train/Urban/masks_png"

os.makedirs("MixMasks", exist_ok=True)  # no error if the folder already exists

### Generate and Save a MixMask per Urban Training Mask

In [None]:
import numpy as np

# Seed torch so the randomly selected classes (and hence the saved MixMasks) are
# reproducible across re-runs; sort the listing so every file gets the same draw.
MIXMASK_SEED = 0
torch.manual_seed(MIXMASK_SEED)

loop = tqdm(sorted(os.listdir(masks_path)))

for filename in loop:
    loop.set_description(f"Processing {filename}")  # show the current file in the progress bar
    mask = pil_loader(os.path.join(masks_path, filename), "L")
    mask = torch.from_numpy(np.array(mask))  # single-channel ("L") label image as a tensor

    # Binary mask over a random half of the classes present in this mask,
    # with a leading dimension added by unsqueeze(0).
    MixMask = generate_class_mask(mask, select_classes_for_mix(extract_classes_from_mask(mask))).unsqueeze(0)
    torch.save(MixMask, os.path.join("MixMasks", f"{filename}.pt"))

### Zip and Copy the MixMasks to Drive

In [None]:
if len(os.listdir('./Train/Urban/masks_png')) == len(os.listdir('./MixMasks')):
    print("Copying to drive...")
    !zip -r MixMasks.zip MixMasks
    !cp MixMasks.zip /content/drive/MyDrive/LoveDA/