In [3]:
import pandas as pd 
import numpy as np 
import os
import glob

In [21]:
from google.colab import drive
drive.mount('/content/drive/')
os.chdir('/content/drive/MyDrive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


### Prepare folder structure for STEGO training in  /datadrive/pytorch-data/
```
dataset_name
|── imgs
|   ├── train
|   |   |── unique_img_name_1.jpg
|   |   └── unique_img_name_2.jpg
|   └── val
|       |── unique_img_name_3.jpg
|       └── unique_img_name_4.jpg
└── labels
    ├── train
    |   |── unique_img_name_1.png
    |   └── unique_img_name_2.png
    └── val
        |── unique_img_name_3.png
        └── unique_img_name_4.png
```



# Prepare Google Colab Environment

In [None]:
#!git clone https://github.com/mhamilton723/STEGO.git

Cloning into 'STEGO'...
remote: Enumerating objects: 201, done.[K
remote: Counting objects: 100% (32/32), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 201 (delta 18), reused 15 (delta 12), pack-reused 169[K
Receiving objects: 100% (201/201), 9.23 MiB | 14.63 MiB/s, done.
Resolving deltas: 100% (98/98), done.


In [1]:
!pip install wget
!pip install torchmetrics
!pip install hydra-core
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git
!pip install omegaconf
!pip install pytorch-lightning
!pip install azureml

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/lucasb-eyer/pydensecrf.git
  Cloning https://github.com/lucasb-eyer/pydensecrf.git to /tmp/pip-req-build-7897peo7
  Running command git clone -q https://github.com/lucasb-eyer/pydensecrf.git /tmp/pip-req-build-7897peo7
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public

## Test on whole Images 

In [10]:
# split images 90/10 train validation (all dataset unique )
df_lab = pd.read_csv('/content/drive/My Drive/all_labeled.csv', index_col=0)
imfilename, area = [], []
for i in range(len(df_lab)):
    f = os.path.basename(df_lab.loc[i,'patch_filename']).replace(".JPG", "").split('_')
    imfilename.append('_'.join(f[1:3]) + ".JPG")
    area.append(f[0])
df_lab['area'] = area
df_lab['imfilename'] = imfilename
df_lab = df_lab.sort_values(by='imfilename').reset_index(drop=True)

Unnamed: 0,Unnamed: 0.1,index,patch_filename,plastic,area,imfilename
0,1874,2412.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,0.0,a1,DJI_0813.JPG
1,1971,2551.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,0.0,a1,DJI_0813.JPG
2,1948,2527.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,0.0,a1,DJI_0813.JPG
3,767,2465.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,1.0,a1,DJI_0813.JPG
4,1831,2363.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,0.0,a1,DJI_0813.JPG
...,...,...,...,...,...,...
7051,2583,4653.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,1.0,a2,DJI_0965.JPG
7052,2584,4654.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,1.0,a2,DJI_0965.JPG
7053,2585,4655.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,1.0,a2,DJI_0965.JPG
7054,2569,4639.0,/home/giorgia/Desktop/MAI/Thesis/images/patche...,0.0,a2,DJI_0965.JPG


In [15]:
import random 

imgs_path = '/content/drive/My Drive/selected_50m/'
imgs = np.unique(df_lab['imfilename'])
wholeimgs = []
for imf in imgs:
    for a in np.unique(df_lab['area']):
      if os.path.exists(os.path.join(imgs_path, a, imf)):
        wholeimgs.append(os.path.join(imgs_path, a, imf))
training_set = random.sample(wholeimgs, int(len(wholeimgs)*.85))
validation_set = list(set(wholeimgs) - set(training_set))
print("TRAINING SET: {}".format(training_set))
print("VALIDATION SET: {}".format(validation_set))

TRAINING SET: ['/content/drive/My Drive/selected_50m/a1/DJI_0913.JPG', '/content/drive/My Drive/selected_50m/a1/DJI_0853.JPG', '/content/drive/My Drive/selected_50m/a2/DJI_0846.JPG', '/content/drive/My Drive/selected_50m/a2/DJI_0900.JPG', '/content/drive/My Drive/selected_50m/a1/DJI_0813.JPG', '/content/drive/My Drive/selected_50m/a1/DJI_0821.JPG', '/content/drive/My Drive/selected_50m/a1/DJI_0828.JPG', '/content/drive/My Drive/selected_50m/a2/DJI_0965.JPG', '/content/drive/My Drive/selected_50m/a2/DJI_0818.JPG', '/content/drive/My Drive/selected_50m/a2/DJI_0869.JPG', '/content/drive/My Drive/selected_50m/a2/DJI_0929.JPG', '/content/drive/My Drive/selected_50m/a2/DJI_0938.JPG', '/content/drive/My Drive/selected_50m/a1/DJI_0860.JPG', '/content/drive/My Drive/selected_50m/a2/DJI_0908.JPG', '/content/drive/My Drive/selected_50m/a1/DJI_0881.JPG', '/content/drive/My Drive/selected_50m/a1/DJI_0869.JPG', '/content/drive/My Drive/selected_50m/a1/DJI_0884.JPG', '/content/drive/My Drive/selected

In [17]:
dest_path = '/datadrive/pytorch-data/'
train_path = os.path.join(dest_path,'litter_dataset', 'imgs', 'train/')
val_path = os.path.join(dest_path,'litter_dataset', 'imgs', 'val/')

In [26]:
import shutil
shutil.rmtree(train_path)
shutil.rmtree(val_path)

In [27]:
if not os.path.exists(train_path): os.makedirs(train_path)
if not os.path.exists(val_path): os.makedirs(val_path)

In [28]:
for t in training_set:
  filename = '_'.join(t.split('/')[-2:])
  os.symlink(t, os.path.join(train_path,filename))
  #shutil.copy(t, t.replace(imgs_path, train_path) )

for t in validation_set:
  filename = '_'.join(t.split('/')[-2:])
  os.symlink(t, os.path.join( val_path, filename))
  #shutil.copy(t, t.replace(imgs_path, val_path) )

# Test on Patches 

In [None]:
training_set = pd.read_csv('/content/drive/MyDrive/training_set.csv', index_col=0).loc[:,'patch_filename'].to_list()
validation_set = pd.read_csv('/content/drive/MyDrive/validation_set.csv', index_col=0).loc[:,'patch_filename'].to_list()

In [None]:
patches_path = '/content/drive/MyDrive/root_patches/patches/'
dest_path = '/datadrive/pytorch-data/'

train_path = os.path.join(dest_path,'litter_dataset', 'imgs', 'train/')
val_path = os.path.join(dest_path,'litter_dataset', 'imgs', 'val/')

training_set = [os.path.join(patches_path, os.path.basename(t)) for t in training_set]
validation_set = [os.path.join(patches_path, os.path.basename(t)) for t in validation_set]

In [None]:
#import shutil
#shutil.rmtree(train_path)
#shutil.rmtree(val_path)

In [None]:
if not os.path.exists(train_path): os.makedirs(train_path)
if not os.path.exists(val_path): os.makedirs(val_path)

In [None]:
for t in training_set:
  os.symlink(t, t.replace(patches_path, train_path))
  #shutil.copy(t, t.replace(patches_path, train_path) )

for t in validation_set:
  os.symlink(t, t.replace(patches_path, val_path))
  #shutil.copy(t, t.replace(patches_path, val_path) )


In [None]:
#import shutil
#shutil.rmtree(os.path.join(dest_path,'imgs'))

change these lines in STEGO/scr/config/train_confg.yaml 

```
dataset_name: "litter_dataset"
dir_dataset_name: "litter_dataset"
dir_dataset_n_classes: 5
```



#### Precompute indexes with KNN

In [32]:
import sys 
sys.path.append('/content/drive/MyDrive/STEGO/src/')
from data import ContrastiveSegDataset
from modules import *
import os
from os.path import join
import hydra
import numpy as np
import torch.multiprocessing
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities.seed import seed_everything
from tqdm import tqdm

os.chdir('/content/drive/MyDrive/STEGO/src/')

def get_feats(model, loader):
    all_feats = []
    for pack in tqdm(loader):
        img = pack["img"]
        feats = F.normalize(model.forward(img.cuda()).mean([2, 3]), dim=1)
        all_feats.append(feats.to("cpu", non_blocking=True))
    return torch.cat(all_feats, dim=0).contiguous()


@hydra.main(config_path="configs", config_name="train_config.yml")
def my_app(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))
    pytorch_data_dir = cfg.pytorch_data_dir
    data_dir = join(cfg.output_root, "data")
    log_dir = join(cfg.output_root, "logs")
    #data_dir = join("/content/drive/MyDrive/stego_train/", "data")
    #log_dir = join("/content/drive/MyDrive/stego_train/", "logs")
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(join(pytorch_data_dir, "nns"), exist_ok=True)

    seed_everything(seed=0)

    print(data_dir)
    print(cfg.output_root)

    image_sets = ["val", "train"]
    dataset_names = ["cocostuff27", "cityscapes", "potsdam"]
    crop_types = ["five", None]

    # Uncomment these lines to run on custom datasets
    dataset_names = ["directory"]
    crop_types = ["five", None]

    res = 256 #256
    n_batches = 16

    if cfg.arch == "dino":
        from modules import DinoFeaturizer, LambdaLayer
        no_ap_model = torch.nn.Sequential(
            DinoFeaturizer(20, cfg),  # dim doesent matter
            LambdaLayer(lambda p: p[0]),
        ).cuda()
    else:
        cut_model = load_model(cfg.model_type, join(cfg.output_root, "data")).cuda()
        no_ap_model = nn.Sequential(*list(cut_model.children())[:-1]).cuda()
    par_model = torch.nn.DataParallel(no_ap_model)

    for crop_type in crop_types:
        for image_set in image_sets:
            for dataset_name in dataset_names:
                nice_dataset_name = cfg.dir_dataset_name if dataset_name == "directory" else dataset_name

                feature_cache_file = join(pytorch_data_dir, "nns", "nns_{}_{}_{}_{}_{}.npz".format(
                    cfg.model_type, nice_dataset_name, image_set, crop_type, res))

                if not os.path.exists(feature_cache_file):
                    print("{} not found, computing".format(feature_cache_file))
                    dataset = ContrastiveSegDataset(
                        pytorch_data_dir=pytorch_data_dir,
                        dataset_name=dataset_name,
                        crop_type=crop_type,
                        image_set=image_set,
                        transform=get_transform(res, False, "center"),
                        target_transform=get_transform(res, True, "center"),
                        cfg=cfg,
                    )

                    loader = DataLoader(dataset, 256, shuffle=False, num_workers=cfg.num_workers, pin_memory=False)

                    with torch.no_grad():
                        normed_feats = get_feats(par_model, loader)
                        all_nns = []
                        step = normed_feats.shape[0] // n_batches
                        print(normed_feats.shape)
                        for i in tqdm(range(0, normed_feats.shape[0], step)):
                            torch.cuda.empty_cache()
                            batch_feats = normed_feats[i:i + step, :]
                            pairwise_sims = torch.einsum("nf,mf->nm", batch_feats, normed_feats)
                            all_nns.append(torch.topk(pairwise_sims, 30)[1])
                            del pairwise_sims
                        nearest_neighbors = torch.cat(all_nns, dim=0)

                        np.savez_compressed(feature_cache_file, nns=nearest_neighbors.numpy())
                        print("Saved NNs", cfg.model_type, nice_dataset_name, image_set)


if __name__ == "__main__":
    prep_args()
    my_app()


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  "Support for .yml files is deprecated. Use .yaml extension for Hydra config files"
See https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_job_working_dir/ for more information.
  configure_logging=with_log_configuration,
Global seed set to 0


output_root: ../
pytorch_data_dir: /datadrive/pytorch-data/
experiment_name: exp4
log_dir: litter_dataset
azureml_logging: false
submitting_to_aml: false
num_workers: 1
max_steps: 5000
batch_size: 8
num_neighbors: 7
dataset_name: directory
dir_dataset_name: litter_dataset
dir_dataset_n_classes: 5
has_labels: false
crop_type: five
crop_ratio: 0.5
res: 224
loader_crop_type: center
extra_clusters: 0
use_true_labels: false
use_recalibrator: false
model_type: vit_tiny
arch: dino
use_fit_model: false
dino_feat_type: feat
projection_type: nonlinear
dino_patch_size: 16
granularity: 1
continuous: true
dim: 70
dropout: true
zero_clamp: true
lr: 0.0005
pretrained_weights: /content/drive/MyDrive/output_dino/checkpoint0090.pth
use_salience: false
stabalize: false
stop_at_zero: true
pointwise: true
feature_samples: 11
neg_samples: 5
aug_alignment_weight: 0.0
correspondence_weight: 1.0
neg_inter_weight: 0.63
pos_inter_weight: 0.25
pos_intra_weight: 0.67
neg_inter_shift: 0.46
pos_inter_shift: 0.12
pos

  "Argument interpolation should be of type InterpolationMode instead of int. "


Pretrained weights found at /content/drive/MyDrive/output_dino/checkpoint0090.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
/datadrive/pytorch-data/nns/nns_vit_tiny_litter_dataset_val_five_256.npz not found, computing
/datadrive/pytorch-data/
litter_dataset


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 1/1 [00:01<00:00,  1.53s/it]

torch.Size([4, 192])



Error executing job with overrides: []
An error occurred during Hydra's exception formatting:
AssertionError()


ValueError: ignored

### Train Segmentation step

In [None]:
import sys
#sys.path.append('/content/drive/MyDrive/STEGO/src/')
from utils import *
from modules import *
from data import *
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datetime import datetime
import hydra
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
import torch.multiprocessing
import seaborn as sns
from pytorch_lightning.callbacks import ModelCheckpoint

#os.chdir('/content/drive/MyDrive/STEGO/src/')

torch.multiprocessing.set_sharing_strategy('file_system')

def get_class_labels(dataset_name):
    if dataset_name.startswith("cityscapes"):
        return [
            'road', 'sidewalk', 'parking', 'rail track', 'building',
            'wall', 'fence', 'guard rail', 'bridge', 'tunnel',
            'pole', 'polegroup', 'traffic light', 'traffic sign', 'vegetation',
            'terrain', 'sky', 'person', 'rider', 'car',
            'truck', 'bus', 'caravan', 'trailer', 'train',
            'motorcycle', 'bicycle']
    elif dataset_name == "cocostuff27":
        return [
            "electronic", "appliance", "food", "furniture", "indoor",
            "kitchen", "accessory", "animal", "outdoor", "person",
            "sports", "vehicle", "ceiling", "floor", "food",
            "furniture", "rawmaterial", "textile", "wall", "window",
            "building", "ground", "plant", "sky", "solid",
            "structural", "water"]
    elif dataset_name == "voc":
        return [
            'background',
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
            'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
    elif dataset_name == "potsdam":
        return [
            'roads and cars',
            'buildings and clutter',
            'trees and vegetation']
    ###########
    elif dataset_name == "litter_dataset":
      print("Defining classes for litter_dataset")
      return ["litter", "sand", "stones", "vegetation", "water"]
    ###########
    else:
        raise ValueError("Unknown Dataset {}".format(dataset_name))


class LitUnsupervisedSegmenter(pl.LightningModule):
    def __init__(self, n_classes, cfg):
        super().__init__()
        self.cfg = cfg
        self.n_classes = n_classes

        if not cfg.continuous:
            dim = n_classes
        else:
            dim = cfg.dim

        data_dir = join(cfg.output_root, "data")
        if cfg.arch == "feature-pyramid":
            cut_model = load_model(cfg.model_type, data_dir).cuda()
            self.net = FeaturePyramidNet(cfg.granularity, cut_model, dim, cfg.continuous)
        elif cfg.arch == "dino":
            self.net = DinoFeaturizer(dim, cfg)
        else:
            raise ValueError("Unknown arch {}".format(cfg.arch))

        self.train_cluster_probe = ClusterLookup(dim, n_classes)

        self.cluster_probe = ClusterLookup(dim, n_classes + cfg.extra_clusters)
        self.linear_probe = nn.Conv2d(dim, n_classes, (1, 1))

        self.decoder = nn.Conv2d(dim, self.net.n_feats, (1, 1))

        self.cluster_metrics = UnsupervisedMetrics(
            "test/cluster/", n_classes, cfg.extra_clusters, True)
        self.linear_metrics = UnsupervisedMetrics(
            "test/linear/", n_classes, 0, False)

        self.test_cluster_metrics = UnsupervisedMetrics(
            "final/cluster/", n_classes, cfg.extra_clusters, True)
        self.test_linear_metrics = UnsupervisedMetrics(
            "final/linear/", n_classes, 0, False)

        self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss()
        self.crf_loss_fn = ContrastiveCRFLoss(
            cfg.crf_samples, cfg.alpha, cfg.beta, cfg.gamma, cfg.w1, cfg.w2, cfg.shift)

        self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(cfg)
        for p in self.contrastive_corr_loss_fn.parameters():
            p.requires_grad = False

        self.automatic_optimization = False

        if self.cfg.dataset_name.startswith("cityscapes"):
            self.label_cmap = create_cityscapes_colormap()
        else:
            self.label_cmap = create_pascal_label_colormap()

        self.val_steps = 0
        self.save_hyperparameters()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        return self.net(x)[1]

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        net_optim, linear_probe_optim, cluster_probe_optim = self.optimizers()

        net_optim.zero_grad()
        linear_probe_optim.zero_grad()
        cluster_probe_optim.zero_grad()

        with torch.no_grad():
            ind = batch["ind"]
            img = batch["img"]
            img_aug = batch["img_aug"]
            coord_aug = batch["coord_aug"]
            img_pos = batch["img_pos"]
            label = batch["label"]
            label_pos = batch["label_pos"]

        feats, code = self.net(img)
        if self.cfg.correspondence_weight > 0:
            feats_pos, code_pos = self.net(img_pos)
        log_args = dict(sync_dist=False, rank_zero_only=True)

        if self.cfg.use_true_labels:
            signal = one_hot_feats(label + 1, self.n_classes + 1)
            signal_pos = one_hot_feats(label_pos + 1, self.n_classes + 1)
        else:
            signal = feats
            signal_pos = feats_pos

        loss = 0

        should_log_hist = (self.cfg.hist_freq is not None) and \
                          (self.global_step % self.cfg.hist_freq == 0) and \
                          (self.global_step > 0)
        if self.cfg.use_salience:
            salience = batch["mask"].to(torch.float32).squeeze(1)
            salience_pos = batch["mask_pos"].to(torch.float32).squeeze(1)
        else:
            salience = None
            salience_pos = None

        if self.cfg.correspondence_weight > 0:
            (
                pos_intra_loss, pos_intra_cd,
                pos_inter_loss, pos_inter_cd,
                neg_inter_loss, neg_inter_cd,
            ) = self.contrastive_corr_loss_fn(
                signal, signal_pos,
                salience, salience_pos,
                code, code_pos,
            )

            if should_log_hist:
                self.logger.experiment.add_histogram("intra_cd", pos_intra_cd, self.global_step)
                self.logger.experiment.add_histogram("inter_cd", pos_inter_cd, self.global_step)
                self.logger.experiment.add_histogram("neg_cd", neg_inter_cd, self.global_step)
            neg_inter_loss = neg_inter_loss.mean()
            pos_intra_loss = pos_intra_loss.mean()
            pos_inter_loss = pos_inter_loss.mean()
            self.log('loss/pos_intra', pos_intra_loss, **log_args)
            self.log('loss/pos_inter', pos_inter_loss, **log_args)
            self.log('loss/neg_inter', neg_inter_loss, **log_args)
            self.log('cd/pos_intra', pos_intra_cd.mean(), **log_args)
            self.log('cd/pos_inter', pos_inter_cd.mean(), **log_args)
            self.log('cd/neg_inter', neg_inter_cd.mean(), **log_args)

            loss += (self.cfg.pos_inter_weight * pos_inter_loss +
                     self.cfg.pos_intra_weight * pos_intra_loss +
                     self.cfg.neg_inter_weight * neg_inter_loss) * self.cfg.correspondence_weight

        if self.cfg.rec_weight > 0:
            rec_feats = self.decoder(code)
            rec_loss = -(norm(rec_feats) * norm(feats)).sum(1).mean()
            self.log('loss/rec', rec_loss, **log_args)
            loss += self.cfg.rec_weight * rec_loss

        if self.cfg.aug_alignment_weight > 0:
            orig_feats_aug, orig_code_aug = self.net(img_aug)
            downsampled_coord_aug = resize(
                coord_aug.permute(0, 3, 1, 2),
                orig_code_aug.shape[2]).permute(0, 2, 3, 1)
            aug_alignment = -torch.einsum(
                "bkhw,bkhw->bhw",
                norm(sample(code, downsampled_coord_aug)),
                norm(orig_code_aug)
            ).mean()
            self.log('loss/aug_alignment', aug_alignment, **log_args)
            loss += self.cfg.aug_alignment_weight * aug_alignment

        if self.cfg.crf_weight > 0:
            crf = self.crf_loss_fn(
                resize(img, 56),
                norm(resize(code, 56))
            ).mean()
            self.log('loss/crf', crf, **log_args)
            loss += self.cfg.crf_weight * crf

        flat_label = label.reshape(-1)
        mask = (flat_label >= 0) & (flat_label < self.n_classes)

        detached_code = torch.clone(code.detach())

        linear_logits = self.linear_probe(detached_code)
        linear_logits = F.interpolate(linear_logits, label.shape[-2:], mode='bilinear', align_corners=False)
        linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes)
        linear_loss = self.linear_probe_loss_fn(linear_logits[mask], flat_label[mask]).mean()
        loss += linear_loss
        self.log('loss/linear', linear_loss, **log_args)

        cluster_loss, cluster_probs = self.cluster_probe(detached_code, None)
        loss += cluster_loss
        self.log('loss/cluster', cluster_loss, **log_args)
        self.log('loss/total', loss, **log_args)

        self.manual_backward(loss)
        net_optim.step()
        cluster_probe_optim.step()
        linear_probe_optim.step()

        if self.cfg.reset_probe_steps is not None and self.global_step == self.cfg.reset_probe_steps:
            print("RESETTING PROBES")
            self.linear_probe.reset_parameters()
            self.cluster_probe.reset_parameters()
            self.trainer.optimizers[1] = torch.optim.Adam(list(self.linear_probe.parameters()), lr=5e-3)
            self.trainer.optimizers[2] = torch.optim.Adam(list(self.cluster_probe.parameters()), lr=5e-3)

        if self.global_step % 2000 == 0 and self.global_step > 0:
            print("RESETTING TFEVENT FILE")
            # Make a new tfevent file
            self.logger.experiment.close()
            self.logger.experiment._get_file_writer()

        return loss

    def on_train_start(self):
        tb_metrics = {
            **self.linear_metrics.compute(),
            **self.cluster_metrics.compute()
        }
        self.logger.log_hyperparams(self.cfg, tb_metrics)

    def validation_step(self, batch, batch_idx):
        img = batch["img"]
        label = batch["label"]
        self.net.eval()

        with torch.no_grad():
            feats, code = self.net(img)
            code = F.interpolate(code, label.shape[-2:], mode='bilinear', align_corners=False)

            linear_preds = self.linear_probe(code)
            linear_preds = linear_preds.argmax(1)
            self.linear_metrics.update(linear_preds, label)

            cluster_loss, cluster_preds = self.cluster_probe(code, None)
            cluster_preds = cluster_preds.argmax(1)
            self.cluster_metrics.update(cluster_preds, label)
            #self.cfg.n_images = 4 ##########################################
            return {
                'img': img[:self.cfg.n_images].detach().cpu(),
                'linear_preds': linear_preds[:self.cfg.n_images].detach().cpu(),
                "cluster_preds": cluster_preds[:self.cfg.n_images].detach().cpu(),
                "label": label[:self.cfg.n_images].detach().cpu()}

    def validation_epoch_end(self, outputs) -> None:
        super().validation_epoch_end(outputs)
        with torch.no_grad():
            tb_metrics = {
                **self.linear_metrics.compute(),
                **self.cluster_metrics.compute(),
            }

            if self.trainer.is_global_zero and not self.cfg.submitting_to_aml:
                #output_num = 0
                output_num = random.randint(0, len(outputs) -1)
                output = {k: v.detach().cpu() for k, v in outputs[output_num].items()}
                #print(len(output))
                fig, ax = plt.subplots(4, self.cfg.n_images, figsize=(self.cfg.n_images * 3, 4 * 3))
                for i in range(self.cfg.n_images):
                    ax[0, i].imshow(prep_for_plot(output["img"][i]))
                    ax[1, i].imshow(self.label_cmap[output["label"][i]])
                    ax[2, i].imshow(self.label_cmap[output["linear_preds"][i]])
                    ax[3, i].imshow(self.label_cmap[self.cluster_metrics.map_clusters(output["cluster_preds"][i])])
                ax[0, 0].set_ylabel("Image", fontsize=16)
                ax[1, 0].set_ylabel("Label", fontsize=16)
                ax[2, 0].set_ylabel("Linear Probe", fontsize=16)
                ax[3, 0].set_ylabel("Cluster Probe", fontsize=16)
                remove_axes(ax)
                plt.tight_layout()
                add_plot(self.logger.experiment, "plot_labels", self.global_step)

                if self.cfg.has_labels:
                    fig = plt.figure(figsize=(13, 10))
                    ax = fig.gca()
                    hist = self.cluster_metrics.histogram.detach().cpu().to(torch.float32)
                    hist /= torch.clamp_min(hist.sum(dim=0, keepdim=True), 1)
                    sns.heatmap(hist.t(), annot=False, fmt='g', ax=ax, cmap="Blues")
                    ax.set_xlabel('Predicted labels')
                    ax.set_ylabel('True labels')
                    names = get_class_labels(self.cfg.dataset_name)
                    if self.cfg.extra_clusters:
                        names = names + ["Extra"]
                    ax.set_xticks(np.arange(0, len(names)) + .5)
                    ax.set_yticks(np.arange(0, len(names)) + .5)
                    ax.xaxis.tick_top()
                    ax.xaxis.set_ticklabels(names, fontsize=14)
                    ax.yaxis.set_ticklabels(names, fontsize=14)
                    colors = [self.label_cmap[i] / 255.0 for i in range(len(names))]
                    [t.set_color(colors[i]) for i, t in enumerate(ax.xaxis.get_ticklabels())]
                    [t.set_color(colors[i]) for i, t in enumerate(ax.yaxis.get_ticklabels())]
                    # ax.yaxis.get_ticklabels()[-1].set_color(self.label_cmap[0] / 255.0)
                    # ax.xaxis.get_ticklabels()[-1].set_color(self.label_cmap[0] / 255.0)
                    plt.xticks(rotation=90)
                    plt.yticks(rotation=0)
                    ax.vlines(np.arange(0, len(names) + 1), color=[.5, .5, .5], *ax.get_xlim())
                    ax.hlines(np.arange(0, len(names) + 1), color=[.5, .5, .5], *ax.get_ylim())
                    plt.tight_layout()
                    add_plot(self.logger.experiment, "conf_matrix", self.global_step)

                    all_bars = torch.cat([
                        self.cluster_metrics.histogram.sum(0).cpu(),
                        self.cluster_metrics.histogram.sum(1).cpu()
                    ], axis=0)
                    ymin = max(all_bars.min() * .8, 1)
                    ymax = all_bars.max() * 1.2

                    fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 1 * 4))
                    ax[0].bar(range(self.n_classes + self.cfg.extra_clusters),
                              self.cluster_metrics.histogram.sum(0).cpu(),
                              tick_label=names,
                              color=colors)
                    ax[0].set_ylim(ymin, ymax)
                    ax[0].set_title("Label Frequency")
                    ax[0].set_yscale('log')
                    ax[0].tick_params(axis='x', labelrotation=90)

                    ax[1].bar(range(self.n_classes + self.cfg.extra_clusters),
                              self.cluster_metrics.histogram.sum(1).cpu(),
                              tick_label=names,
                              color=colors)
                    ax[1].set_ylim(ymin, ymax)
                    ax[1].set_title("Cluster Frequency")
                    ax[1].set_yscale('log')
                    ax[1].tick_params(axis='x', labelrotation=90)

                    plt.tight_layout()
                    add_plot(self.logger.experiment, "label frequency", self.global_step)

            if self.global_step > 2:
                self.log_dict(tb_metrics)

                if self.trainer.is_global_zero and self.cfg.azureml_logging:
                    from azureml.core.run import Run
                    run_logger = Run.get_context()
                    for metric, value in tb_metrics.items():
                        run_logger.log(metric, value)

            self.linear_metrics.reset()
            self.cluster_metrics.reset()

    def configure_optimizers(self):
        main_params = list(self.net.parameters())

        if self.cfg.rec_weight > 0:
            main_params.extend(self.decoder.parameters())

        net_optim = torch.optim.Adam(main_params, lr=self.cfg.lr)
        linear_probe_optim = torch.optim.Adam(list(self.linear_probe.parameters()), lr=5e-3)
        cluster_probe_optim = torch.optim.Adam(list(self.cluster_probe.parameters()), lr=5e-3)

        return net_optim, linear_probe_optim, cluster_probe_optim


@hydra.main(config_path="configs", config_name="train_config.yml")
def my_app(cfg: DictConfig) -> None:
    OmegaConf.set_struct(cfg, False)
    print(OmegaConf.to_yaml(cfg))
    pytorch_data_dir = cfg.pytorch_data_dir
    data_dir = join(cfg.output_root, "data")
    log_dir = join(cfg.output_root, "logs")
    checkpoint_dir = join(cfg.output_root, "checkpoints")

    prefix = "{}/{}_{}".format(cfg.log_dir, cfg.dataset_name, cfg.experiment_name)
    name = '{}_date_{}'.format(prefix, datetime.now().strftime('%b%d_%H-%M-%S'))
    cfg.full_name = prefix

    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    seed_everything(seed=0)

    print(data_dir)
    print(cfg.output_root)

    geometric_transforms = T.Compose([
        T.RandomHorizontalFlip(),
        T.RandomResizedCrop(size=cfg.res, scale=(0.8, 1.0))
    ])
    photometric_transforms = T.Compose([
        T.ColorJitter(brightness=.3, contrast=.3, saturation=.3, hue=.1),
        T.RandomGrayscale(.2),
        T.RandomApply([T.GaussianBlur((5, 5))])
    ])

    sys.stdout.flush()

    train_dataset = ContrastiveSegDataset(
        pytorch_data_dir=pytorch_data_dir,
        dataset_name=cfg.dataset_name,
        crop_type=cfg.crop_type,
        image_set="train",
        transform=get_transform(cfg.res, False, cfg.loader_crop_type),
        target_transform=get_transform(cfg.res, True, cfg.loader_crop_type),
        cfg=cfg,
        aug_geometric_transform=geometric_transforms,
        aug_photometric_transform=photometric_transforms,
        num_neighbors=cfg.num_neighbors,
        mask=True,
        pos_images=True,
        pos_labels=True
    )

    if cfg.dataset_name == "voc":
        val_loader_crop = None
    else:
        val_loader_crop = "center"

    val_dataset = ContrastiveSegDataset(
        pytorch_data_dir=pytorch_data_dir,
        dataset_name=cfg.dataset_name,
        crop_type=None,
        image_set="val",
        transform=get_transform(320, False, val_loader_crop),
        target_transform=get_transform(320, True, val_loader_crop),
        mask=True,
        cfg=cfg,
    )

    #val_dataset = MaterializedDataset(val_dataset)
    train_loader = DataLoader(train_dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, pin_memory=True)

    if cfg.submitting_to_aml:
        val_batch_size = 16
    else:
        val_batch_size = cfg.batch_size

    val_loader = DataLoader(val_dataset, val_batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)

    model = LitUnsupervisedSegmenter(train_dataset.n_classes, cfg)

    tb_logger = TensorBoardLogger(
        join(log_dir, name),
        default_hp_metric=False
    )

    if cfg.submitting_to_aml:
        gpu_args = dict(gpus=1, val_check_interval=250)

        if gpu_args["val_check_interval"] > len(train_loader):
            gpu_args.pop("val_check_interval")

    else:
        #gpu_args = dict(gpus=-1, accelerator='ddp', val_check_interval=cfg.val_freq)
        gpu_args = dict(gpus=-1, accelerator=None, val_check_interval=cfg.val_freq)
        # gpu_args = dict(gpus=1, accelerator='ddp', val_check_interval=cfg.val_freq)

        if gpu_args["val_check_interval"] > len(train_loader) // 4:
            gpu_args.pop("val_check_interval")

    trainer = Trainer(
        log_every_n_steps=cfg.scalar_log_freq,
        logger=tb_logger,
        max_steps=cfg.max_steps,
        callbacks=[
            ModelCheckpoint(
                dirpath=join(checkpoint_dir, name),
                every_n_train_steps=400,
                save_top_k=2,
                monitor="test/cluster/mIoU",
                mode="max",
            )
        ],
        **gpu_args
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    prep_args()
    my_app()


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  "Support for .yml files is deprecated. Use .yaml extension for Hydra config files"
See https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_job_working_dir/ for more information.
  configure_logging=with_log_configuration,
Global seed set to 0


output_root: ../
pytorch_data_dir: /datadrive/pytorch-data/
experiment_name: exp3
log_dir: litter_dataset
azureml_logging: false
submitting_to_aml: false
num_workers: 1
max_steps: 5000
batch_size: 8
num_neighbors: 7
dataset_name: directory
dir_dataset_name: litter_dataset
dir_dataset_n_classes: 5
has_labels: false
crop_type: five
crop_ratio: 0.5
res: 128
loader_crop_type: center
extra_clusters: 0
use_true_labels: false
use_recalibrator: false
model_type: vit_tiny
arch: dino
use_fit_model: false
dino_feat_type: feat
projection_type: nonlinear
dino_patch_size: 16
granularity: 1
continuous: true
dim: 70
dropout: true
zero_clamp: true
lr: 0.0005
pretrained_weights: /content/drive/MyDrive/output_dino/checkpoint0090.pth
use_salience: false
stabalize: false
stop_at_zero: true
pointwise: true
feature_samples: 11
neg_samples: 5
aug_alignment_weight: 0.0
correspondence_weight: 1.0
neg_inter_weight: 0.63
pos_inter_weight: 0.25
pos_intra_weight: 0.67
neg_inter_shift: 0.46
pos_inter_shift: 0.12
pos

  "Argument interpolation should be of type InterpolationMode instead of int. "


Loading vit tiny pretrained Dino model


                not been set for this class (UnsupervisedMetrics). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Pretrained weights found at /content/drive/MyDrive/output_dino/checkpoint0090.pth and loaded with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])


Missing logger folder: ../logs/litter_dataset/directory_exp3_date_Jul09_14-45-58/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                     | Type                       | Params
-------------------------------------------------------------------------
0  | net                      | DinoFeaturizer             | 5.6 M 
1  | train_cluster_probe      | ClusterLookup              | 350   
2  | cluster_probe            | ClusterLookup              | 350   
3  | linear_probe             | Conv2d                     | 355   
4  | decoder                  | Conv2d                     | 13.6 K
5  | cluster_metrics          | UnsupervisedMetrics        | 0     
6  | linear_metrics           | UnsupervisedMetrics        | 0     
7  | test_cluster_metrics     | UnsupervisedMetrics        | 0     
8  | test_linear_metrics      | UnsupervisedMetrics        | 0     
9  | linear_probe_loss_fn     | CrossEntropyLoss           | 0     
10 | crf_loss_fn              | Contras

Sanity Checking: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Training: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validation: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validation: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validation: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validation: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validation: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validation: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Validation: 0it [00:00, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
