<a href="https://colab.research.google.com/github/kashperova/ssl-hsi-course-work/blob/main/notebooks/ssl_conformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab VM at /content/gdrive.
# NOTE(review): none of the visible cells read from or write to this mount
# (the dataset root used later is "../data") -- confirm it is still needed.
from google.colab import drive

drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [1]:
!git clone https://@github.com/kashperova/ssl-hsi-course-work.git

Cloning into 'ssl-hsi-course-work'...
remote: Enumerating objects: 87, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 87 (delta 13), reused 14 (delta 6), pack-reused 51 (from 3)[K
Receiving objects: 100% (87/87), 93.21 MiB | 15.59 MiB/s, done.
Resolving deltas: 100% (14/14), done.


In [1]:
# Work from the repository's src/ directory so the project packages imported
# below (models, modules, config, utils) resolve and "../data" points at the
# repository's data folder.
# NOTE(review): the path is relative, so re-running this cell from inside
# src/ will presumably fail -- run it once per fresh runtime.
%cd ssl-hsi-course-work/src

/content/ssl-hsi-course-work/src


In [3]:
!curl -LsSf https://astral.sh/uv/install.sh | sh

downloading uv 0.7.9 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!


In [None]:
!uv pip install --system scikit-learn==1.6.0
!uv pip install --system seaborn==0.13.0
!uv pip install --system matplotlib==3.10.1
!uv pip install --system plotly==6.0.0
!uv pip install --system torchmetrics==1.7.2
!uv pip install --system scipy==1.15.2

In [2]:
from models.conformer.model import ModModel
from modules.trainers.noisy_student import NoisyStudentTrainer
from modules.datasets.hsi import HyperspectralDataset
from config.train_config import BaseTrainConfig
from utils.seed import set_seed
from utils.metrics import Metrics, Task
from utils.data import load_hsi_dataset, get_stratified_subset

import os
import wandb
import torch.optim as optim
from torch.utils.data import Subset, Dataset
from sklearn.model_selection import train_test_split

from torch import nn

In [6]:
# Never paste the API key into the notebook. Provide it through the
# environment (e.g. Colab secrets) before running this cell. An existing key
# is kept as-is; if none is set the variable stays empty and wandb asks for
# the key interactively on init.
os.environ.setdefault("WANDB_API_KEY", "")

In [3]:
# Start the first W&B run. No project or run name is passed, so wandb falls
# back to its defaults for both.
wandb.init()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mskashperova[0m ([33mkashperova-test[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Seed the random number generators through the project helper before any
# dataset split or model is created. Which libraries it seeds is decided in
# utils.seed, which is not visible here.
set_seed(42)

Random seed set to 42


In [5]:
class TeacherConfig(BaseTrainConfig):
    """Training settings passed to NoisyStudentTrainer as ``teacher_config``."""

    # Number of training epochs for the teacher stage.
    epochs: int = 50
    # Batch sizes for the training and evaluation loaders.
    train_batch_size: int = 64
    eval_batch_size: int = 64
    # NOTE(review): presumably the train fraction of an internal train/eval
    # split performed by the trainer -- confirm against BaseTrainConfig.
    train_test_split: float = 0.7


class StudentConfig(BaseTrainConfig):
    """Training settings passed to NoisyStudentTrainer as ``student_config``.

    The values currently mirror TeacherConfig; the class is kept separate so
    the student stage can be tuned independently.
    """

    # Number of training epochs for the student stage.
    epochs: int = 50
    # Batch sizes for the training and evaluation loaders.
    train_batch_size: int = 64
    eval_batch_size: int = 64
    # NOTE(review): presumably the train fraction of an internal train/eval
    # split performed by the trainer -- confirm against BaseTrainConfig.
    train_test_split: float = 0.7

In [7]:
def load_data(
    dataset_name: str,
    test_size: float = 0.8,
    *,
    root_dir: str = "../data",
    pca_components: int = 30,
    patch_size: int = 15,
    seed: "int | None" = 42,
):
    """Load a hyperspectral scene and split it into labeled / unlabeled parts.

    Args:
        dataset_name: name forwarded to ``load_hsi_dataset``.
        test_size: fraction of samples that goes to the *unlabeled* subset;
            the remaining ``1 - test_size`` stays labeled.
        root_dir: directory forwarded to ``load_hsi_dataset``.
        pca_components: number of PCA components forwarded to
            ``load_hsi_dataset``. ``train`` builds its models with
            ``in_channels=30``, so change both together.
        patch_size: patch size forwarded to ``load_hsi_dataset``.
        seed: ``random_state`` of the stratified split. A fixed value makes
            the split independent of how many cells ran before this call;
            pass ``None`` to draw from the global NumPy RNG instead.

    Returns:
        ``(labeled_dataset, unlabeled_dataset)``: a ``Subset`` of the full
        dataset, and the complementary ``Subset`` wrapped in
        ``UnlabeledDataset``.
    """
    patches, labels = load_hsi_dataset(
        dataset_name=dataset_name,
        root_dir=root_dir,
        pca_components=pca_components,
        patch_size=patch_size,
    )
    dataset = HyperspectralDataset(patches, labels)

    # Read the targets back through the dataset so stratification uses exactly
    # the labels the dataset yields per index.
    targets = [dataset[i][1] for i in range(len(dataset))]
    labeled_indices, unlabeled_indices = train_test_split(
        range(len(dataset)),
        test_size=test_size,
        stratify=targets,
        random_state=seed,
    )

    labeled_dataset = Subset(dataset, labeled_indices)
    # NOTE(review): UnlabeledDataset is neither imported nor defined in the
    # visible cells, so this line raises NameError on a fresh kernel. Restore
    # its definition/import before relying on Restart & Run All.
    unlabeled_dataset = UnlabeledDataset(Subset(dataset, unlabeled_indices))

    return labeled_dataset, unlabeled_dataset

In [10]:
def train(labeled, unlabeled, num_classes: int, save_dir: str):
    """Run one Noisy Student experiment: teacher stage, then student stage.

    Args:
        labeled: dataset handed to the trainer as ``labeled_dataset``.
        unlabeled: dataset handed to the trainer as ``unlabeled_dataset``.
        num_classes: number of classes for both networks and for the metrics.
        save_dir: accepted but currently unused -- the trainer is constructed
            without a save directory, so this value has no effect.
    """

    def build_branch():
        # One branch = a network plus its own Adam optimizer and
        # ReduceLROnPlateau scheduler.
        network = ModModel(in_channels=30, num_classes=num_classes)
        optimizer = optim.Adam(network.parameters(), lr=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=10, factor=0.05
        )
        return network, optimizer, scheduler

    # The teacher network is created before the student network.
    teacher_net, teacher_opt, teacher_sched = build_branch()
    student_net, student_opt, student_sched = build_branch()

    trainer = NoisyStudentTrainer(
        teacher_model=teacher_net,
        student_model=student_net,
        loss_fn=nn.CrossEntropyLoss(),
        teacher_optimizer=teacher_opt,
        student_optimizer=student_opt,
        labeled_dataset=labeled,
        unlabeled_dataset=unlabeled,
        teacher_lr_scheduler=teacher_sched,
        student_lr_scheduler=student_sched,
        teacher_config=TeacherConfig(),
        student_config=StudentConfig(),
        metrics=Metrics(
            task=Task.MULTICLASS_CLASSIFICATION,
            num_classes=num_classes,
            average="micro",
        ),
    )

    trainer.train_teacher(verbose=False)
    trainer.train_student(verbose=False)


def reinit_wandb() -> None:
    """Finish the current W&B run and start a new one with default settings.

    Called between experiments so each one is logged to its own run.
    """
    wandb.finish(quiet=True)
    wandb.init()

In [45]:
# Indian Pines, 16 classes: 20% of the samples labeled, 80% unlabeled.
labeled, unlabeled = load_data(dataset_name="IndianPines", test_size=0.8)
train(labeled, unlabeled, num_classes=16, save_dir="conformer_ssl_20_ip")

Training: 100%|██████████| 50/50 [00:56<00:00,  1.13s/it]


Pseudo Label Annotation


Training: 100%|██████████| 50/50 [04:21<00:00,  5.22s/it]


In [None]:
# Close the previous experiment's W&B run and open a fresh one.
reinit_wandb()

In [48]:
# Indian Pines, 16 classes: 10% of the samples labeled, 90% unlabeled.
labeled, unlabeled = load_data(dataset_name="IndianPines", test_size=0.9)
train(labeled, unlabeled, num_classes=16, save_dir="conformer_ssl_10_ip")

Training: 100%|██████████| 50/50 [00:29<00:00,  1.67it/s]


Pseudo Label Annotation


Training: 100%|██████████| 50/50 [03:58<00:00,  4.77s/it]


In [49]:
# Close the previous experiment's W&B run and open a fresh one.
reinit_wandb()

In [50]:
# Pavia University, 9 classes: 20% of the samples labeled, 80% unlabeled.
labeled, unlabeled = load_data(dataset_name="PaviaUniversity", test_size=0.8)
train(labeled, unlabeled, num_classes=9, save_dir="conformer_ssl_20_pu")

Training: 100%|██████████| 50/50 [03:51<00:00,  4.64s/it]


Pseudo Label Annotation


Training: 100%|██████████| 50/50 [18:55<00:00, 22.70s/it]


In [None]:
# Close the previous experiment's W&B run and open a fresh one.
reinit_wandb()

In [11]:
# Pavia University, 9 classes: 10% of the samples labeled, 90% unlabeled.
labeled, unlabeled = load_data(dataset_name="PaviaUniversity", test_size=0.9)
train(labeled, unlabeled, num_classes=9, save_dir="conformer_ssl_10_pu")

Training: 100%|██████████| 50/50 [02:02<00:00,  2.45s/it]


Pseudo Label Annotation


Training: 100%|██████████| 50/50 [19:39<00:00, 23.59s/it]


In [None]:
# Close the previous experiment's W&B run and open a fresh one.
reinit_wandb()

In [13]:
# KSC, 13 classes: 20% of the samples labeled, 80% unlabeled.
labeled, unlabeled = load_data(dataset_name="KSC", test_size=0.8)
train(labeled, unlabeled, num_classes=13, save_dir="conformer_ssl_20_ksc")

Training: 100%|██████████| 50/50 [00:37<00:00,  1.32it/s]


Pseudo Label Annotation


Training: 100%|██████████| 50/50 [01:54<00:00,  2.29s/it]


In [None]:
# Close the previous experiment's W&B run and open a fresh one.
reinit_wandb()

In [None]:
# KSC, 13 classes: 10% of the samples labeled, 90% unlabeled.
labeled, unlabeled = load_data(dataset_name="KSC", test_size=0.9)
train(labeled, unlabeled, num_classes=13, save_dir="conformer_ssl_10_ksc")