## Demo
This notebook is a demonstration of Non-Negative Positive Unlabeled Learning.

In [None]:
import os
import pprint

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from dotenv import load_dotenv
from sagemaker.pytorch import PyTorch
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from src.models import models
from src.trainer import execute_exp
from src.utils import config
from src.utils import visualizer

# Populate os.environ before any later cell reads it; the SageMaker cell below
# looks up several variables there (presumably sourced from a local .env file
# -- confirm where the file is expected to live).
load_dotenv()

## Load the configuration file for this experiment

In [None]:
# Path to the TOML configuration file, relative to the notebook's working directory.
config_file = "./config/config.toml"
# Parse it into `args`; the cells below read experiment settings as attributes
# of this object (e.g. args.in_features, args.use_sagemaker_training).
args = config.ConfigurationParser(config_file)

# Pretty-print every attribute stored on `args` so the settings used for this
# run are visible in the notebook output.
pprint.pprint(vars(args))

## Training and Evaluation.
The following cell runs training and evaluation.

In [None]:
# Run the experiment with the loaded configuration. `iteration=1` presumably
# requests a single trial (confirm against execute_exp.repeated_trials). The
# call returns two objects, unpacked here as the training and validation
# histories that the plotting cell below consumes.
train_history_df, valid_history_df = execute_exp.repeated_trials(args=args, iteration=1)

### Run experiment on SageMaker Training Jobs

In [None]:
# Submit a remote training job only when the configuration opts in.
if args.use_sagemaker_training:
    # The job name is the last path component of the repository name.
    repository_name = os.environ['REPOSITORY_NAME']
    base_job_name = repository_name.split('/')[-1]
    print(base_job_name)

    # Read each remaining environment variable exactly once, in the order the
    # estimator arguments first need them.
    s3_bucket = os.environ['BUCKET_NAME']
    aws_account_id = os.environ['AWS_ACCOUNT_ID']
    aws_region = os.environ['AWS_REGION']
    exec_role_name = os.environ['SAGEMAKER_EXEC_ROLE_NAME']

    estimator_kwargs = {
        "source_dir": "src",
        "entry_point": "main.py",
        "dependencies": ["src", "config", "outputs"],
        "base_job_name": base_job_name,
        "instance_type": "ml.g4dn.xlarge",
        "instance_count": 1,
        "checkpoint_s3_uri": f"s3://{s3_bucket}/{base_job_name}/",
        # "checkpoint_local_path": "/app/outputs/models",
        "image_uri": f"{aws_account_id}.dkr.ecr.{aws_region}.amazonaws.com/{repository_name}",
        "role": f"arn:aws:iam::{aws_account_id}:role/{exec_role_name}",
        "output_path": f"s3://{s3_bucket}",
    }
    estimator = PyTorch(**estimator_kwargs)
    estimator.fit()

## Visualizing Training and Evaluation History

The following cell plots how each metric (loss, risk, and accuracy) changes over the course of training and evaluation.

In [None]:
# Bundle both history objects into a (train, valid) tuple and hand it to the
# project's plotting helper.
results = (train_history_df, valid_history_df)
visualizer.plot_history(results=results)

## Visualizing Inference Using t-SNE

In [None]:
# --- Restore the trained model ------------------------------------------------
# Rebuild the network with the layer sizes from the configuration, load the
# weights saved under ./outputs/models, and switch to evaluation mode.
model = models.PositiveUnlabeledModel(
    in_features=args.in_features,
    hide_features=args.hide_features,
    out_features=args.out_features
)
model.load(model_save_dir="./outputs/models")
model.eval()

# --- Build the validation data pipeline ---------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# NOTE(review): the import cell binds `datasets` to `torchvision.datasets`,
# which does not appear to provide `PositiveUnlabeledMNIST`. A project-local
# dataset module is presumably intended -- confirm and fix the import.
valid_dataset = datasets.PositiveUnlabeledMNIST(train=False, transform=transform)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

# --- Predict a label (-1 or 1) for every validation sample --------------------
# Collect per-batch results in a list and concatenate once, instead of
# re-copying a growing array on every iteration.
batch_predictions = []
with torch.no_grad():
    for batch_features, batch_targets in valid_dataloader:
        output = model(batch_features)
        # Threshold the raw output at zero: negative -> -1, otherwise -> 1.
        output = torch.where(output < 0, -1, 1)
        batch_predictions.append(
            output.view_as(batch_targets).type(torch.float).numpy()
        )
predicts = (
    np.concatenate(batch_predictions, axis=-1)
    if batch_predictions
    else np.array([])
)

# --- Project the raw features to 2-D with t-SNE -------------------------------
# Flatten each sample to a vector. The leading dimension is taken from the
# array itself so the cell works for any validation set size.
feature = valid_dataset.feature.numpy()
feature = feature.reshape(feature.shape[0], -1)
targets = valid_dataset.targets.numpy()

tsne = TSNE(n_components=2, random_state=42, init="pca", learning_rate="auto")
feature_reduced = tsne.fit_transform(feature)

# --- Scatter plots: actual labels (left) vs. predicted labels (right) ---------
sns.set(style="darkgrid")
sns.set_palette("bright")
setting = {
    "target": [-1, 1],
    "label": ["unlabeled", "positive"],
    "title": ["Actual", "Predict"],
    "color": ["#52FFB8", "#3626A7"]
}

# Create the figure exactly once; a second plt.figure() call would leave an
# extra, empty figure in the cell output.
fig, axes = plt.subplots(1, 2, figsize=(16, 8))
for target, label, color in zip(setting["target"], setting["label"], setting["color"]):
    # Boolean masks selecting the samples whose actual / predicted label
    # equals the current target value.
    actual_indices = (targets == target)
    predict_indices = (predicts == target)
    arrays = (feature_reduced[actual_indices], feature_reduced[predict_indices])

    for ax, title, array in zip(axes, setting["title"], arrays):
        ax.scatter(
            x=array[:, 0],
            y=array[:, 1],
            color=color,
            edgecolors="white",
            s=50,
            alpha=0.75,
            label=label
        )
        ax.legend()
        ax.set_title(title)
plt.show()