In [None]:
import sys
from pathlib import Path

In [None]:
import numpy as np
import tensorflow as tf

from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

In [None]:
import tensorflow_datasets as tfds

In [None]:
# Put the local `ssl` project directory on the import path so that the
# `src.*` imports in the next cell resolve. The relative path is resolved
# against the notebook's current working directory.
_ssl_root = str(Path("../../ssl").resolve())
# Only append once: re-running this cell in the same kernel must not pile up
# duplicate entries in sys.path.
if _ssl_root not in sys.path:
    sys.path.append(_ssl_root)

In [None]:
from src.models.pi_model.pi_model import PiModel
from src.models.pi_model.pi_model_config import PiModelConfig
from src.trainers.pseudo_label.pseudo_label import PseudoLabelTrainer
from src.trainers.pseudo_label.pseudo_label_config import PseudoLabelTrainerConfig
from src.data_loaders.pseudo_label.pseudo_label import PseudoLabelDataLoader
from src.data_loaders.pseudo_label.pseudo_label import PseudoLabelDataLoaderConfig

# Description

In this notebook, a Pi-Model network is trained on the STL-10 dataset with the pseudo-label method, using the labelled `train` split together with the `unlabelled` split; the `test` split is used for validation.

## Set up Experiment

In [None]:
class TrainerConfig(PseudoLabelTrainerConfig):
    """Training hyperparameters for this experiment.

    Overrides selected attributes of ``PseudoLabelTrainerConfig``; any attribute
    not listed here keeps the value inherited from the base class.
    """
    num_epochs = 200
    # NOTE(review): t1, t2 and alpha presumably parameterise the ramp-up
    # schedule of the unlabelled-loss weight (zero until epoch t1, growing to
    # alpha by epoch t2) -- confirm against PseudoLabelTrainer.
    t1 = 50
    t2 = 100
    alpha = 5.0

# Config instance that is passed to PseudoLabelTrainer further down.
train_config = TrainerConfig()

In [None]:
class ModelConfig(PiModelConfig):
    """Model settings for this experiment, overriding ``PiModelConfig`` values."""
    # presumably (height, width, channels) of the STL-10 input images --
    # TODO confirm against the dataset info
    input_shape = (96, 96, 3)
    # presumably the number of output classes; agrees with num_classes = 10 in
    # DataLoaderConfig -- confirm against PiModel
    output_shape = 10

# Config instance that is passed to PiModel further down.
model_config = ModelConfig()

In [None]:
class DataLoaderConfig(PseudoLabelDataLoaderConfig):
    """Data-loading settings, overriding ``PseudoLabelDataLoaderConfig`` values.

    The same config instance is used for both the training and the validation
    loader.
    """
    batch_size = 64
    num_classes = 10
    # Set to the size of the combined training set -- presumably 5,000 labelled
    # plus 100,000 unlabelled STL-10 images; confirm against the dataset info.
    shuffle_buffer_size = 105000

# Config instance shared by the PseudoLabelDataLoader calls further down.
data_loader_config = DataLoaderConfig()

## Get Datasets

In [None]:
# Load the three STL-10 splits as (image, label) pairs. The split names are
# listed in the same order as the variables they are unpacked into.
unlabeled_train_dataset, labeled_train_dataset, test_dataset = (
    tfds.load("stl10", split=split_name, as_supervised=True)
    for split_name in ("unlabelled", "train", "test")
)

In [None]:
# Build one training stream: the labelled examples come first, followed by
# the unlabelled ones.
train_dataset = labeled_train_dataset.concatenate(unlabeled_train_dataset)

In [None]:
# Wrap the combined labelled + unlabelled dataset in the project data loader,
# then call the loader in training mode to obtain the training pipeline.
train_loader = PseudoLabelDataLoader(train_dataset, data_loader_config)
train_data = train_loader(training=True)

In [None]:
# Wrap the STL-10 test split in the project data loader, then call the loader
# in non-training mode to obtain the validation pipeline.
val_loader = PseudoLabelDataLoader(test_dataset, data_loader_config)
val_data = val_loader(training=False)

In [None]:
# Report the cardinality of both pipelines. tf.data returns a negative
# sentinel here when the number of elements is infinite or unknown.
print(
    f"Train dataset size: {train_data.cardinality()}",
    f"Validation dataset size: {val_data.cardinality()}",
    sep="\n",
)

## Train Model

In [None]:
# Instantiate PiModel with the experiment's model config, then call the
# instance; the result of that call is what gets trained below.
pi_model_builder = PiModel(model_config)
model = pi_model_builder()

In [None]:
# Hand the model, the training pipeline and the hyperparameters to the
# trainer; the pipeline built from the test split serves as validation data.
trainer = PseudoLabelTrainer(
    model,
    train_data,
    train_config,
    val_dataset=val_data,
)

In [None]:
# Start training with the trainer configured in the previous cell.
trainer.train()