# Experiment 4: Architecture Tests
This notebook tests how varying the neural network architecture affects the active learning process when using zero-shot transfer learning.

Network Architectures:
- ResNet
- ShuffleNet
- AlexNet
- DenseNet
- GoogLeNet
- MobileNetV2
- ResNeXt
- Wide ResNet

Datasets:
- OpenSARShip
- FUSAR-Ship

## TODO: 

1. Look through errors.
2. This should be done up to the TODO comments.
3. This runs, but it takes a long time on my computer (even with an M1).
4. Results are saved. Last run: 12/27/22, around 1 am.
5. It might be inefficient because the batch size is very large; it shouldn't have taken hours to run. This run did not use the GPU.
6. Should we run this without data augmentation so we can isolate that source of noise? (The current run is without data augmentation.)

In [1]:
import os

import numpy as np
import pandas as pd
import torch as th

import graphlearning as gl

import batch_active_learning as bal
import utils

#TODO: ENSURE CORRECT PARAMS in bal.coreset_run_experiment

## Parameters and Function Definitions

In [2]:
# Non-default parameters shared by the experiment and save cells below.

num_experiments: int = 1  # repetitions averaged per (dataset, network) pair

save_path: str = "Experiment Results/Experiment 4/"  # output directory for results

In [3]:
# Zero-shot TL: features come straight from a pretrained network.
# `transformed=False` below means NO data augmentation is applied.
def experiment4(dataset, network, num_experiments=num_experiments):
    """Run zero-shot batch active learning for one (dataset, network) pair.

    Features are extracted with the pretrained ``network`` via
    ``utils.encode_pretrained`` (no augmentation), a Gaussian kNN graph is
    built over them, ``bal.coreset_dijkstras`` seeds the labeled set, and
    ``bal.coreset_run_experiment`` runs ``max_new_samples // bal.BATCH_SIZE``
    active learning iterations (method="Laplace", al_mtd="local_max",
    acq_fun="uc").

    Parameters
    ----------
    dataset : str
        Entry of ``utils.AVAILABLE_SAR_DATASETS``; "mstar" is rejected.
    network : str
        Entry of ``utils.PYTORCH_NEURAL_NETWORKS``.
    num_experiments : int
        Number of repetitions to average over; must be >= 1. The default is
        the notebook-level ``num_experiments`` captured at definition time.

    Returns
    -------
    float
        Mean over repetitions of the last value in ``acc_vals`` returned by
        ``bal.coreset_run_experiment``.

    Raises
    ------
    ValueError
        If ``num_experiments`` < 1.
    AssertionError
        If ``dataset`` or ``network`` is not supported.
    """
    assert dataset in utils.AVAILABLE_SAR_DATASETS, "Invalid dataset"
    assert dataset != "mstar", "Invalid dataset: not testing MSTAR"
    assert network in utils.PYTORCH_NEURAL_NETWORKS, "Invalid Neural Network"
    if num_experiments < 1:
        raise ValueError(f"num_experiments must be >= 1, got {num_experiments}")

    max_new_samples = bal.MAX_NEW_SAMPLES_DICT[dataset]
    X, labels = utils.encode_pretrained(dataset, network, transformed=False)

    # Convert to NumPy *before* any graph construction so knnsearch, knn and
    # the coreset routines all see the same array. detach().cpu() keeps this
    # working if features ever carry grad or live on a GPU/MPS device.
    if isinstance(X, th.Tensor):
        X = X.detach().cpu().numpy()
    if isinstance(labels, th.Tensor):
        labels = labels.detach().cpu().numpy()

    # Nearest-neighbor data is computed once and reused by every stage below.
    knn_data = gl.weightmatrix.knnsearch(
        X, utils.KNN_NUM, method="annoy", similarity="angular"
    )

    # Create graph objects
    W = gl.weightmatrix.knn(X, utils.KNN_NUM, kernel="gaussian", knn_data=knn_data)
    G = gl.graph(W)

    num_iter = max_new_samples // bal.BATCH_SIZE
    acc_results = np.zeros(num_experiments)

    for i in range(num_experiments):
        # rate=1 -> one labeled sample per class, so every label is
        # represented in the initial core set.
        initial = gl.trainsets.generate(labels, rate=1).tolist()

        coreset = bal.coreset_dijkstras(
            G,
            rad=bal.DENSITY_RADIUS,
            data=X,
            initial=initial,
            density_info=(True, bal.DENSITY_RADIUS, 1),
            knn_data=knn_data,
        )

        # TODO(review): confirm these parameters against
        # bal.coreset_run_experiment (see TODO in the import cell).
        _, _, acc_vals, _ = bal.coreset_run_experiment(
            X,
            labels,
            W,
            coreset,
            num_iter=num_iter,
            method="Laplace",
            display=False,
            use_prior=False,
            al_mtd="local_max",
            acq_fun="uc",
            knn_data=knn_data,
            mtd_para=None,
            savefig=False,
            batchsize=bal.BATCH_SIZE,
            dist_metric="angular",
            knn_size=utils.KNN_NUM,
            q=1,
            thresholding=0,
        )

        # Only the accuracy after the final active learning iteration is kept.
        acc_results[i] = acc_vals[-1]

    return np.mean(acc_results)


## Experiments

In [4]:
# Run every non-MSTAR dataset against every pretrained architecture.
# The dataset list is derived from utils (instead of hard-coded keys plus a
# positional [1:] slice) so results_dict always matches what is iterated.
experiment_datasets = [d for d in utils.AVAILABLE_SAR_DATASETS if d != "mstar"]
results_dict = {dataset: {} for dataset in experiment_datasets}

for dataset in experiment_datasets:
    for network in utils.PYTORCH_NEURAL_NETWORKS:
        print(f"{dataset}_{network}")
        results_dict[dataset][network] = experiment4(
            dataset, network, num_experiments=num_experiments
        )


Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /Users/jameschapman/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


  0%|          | 0.00/30.8M [00:00<?, ?B/s]

Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Downloading: "https://download.pytorch.org/models/googlenet-1378be20.pth" to /Users/jameschapman/.cache/torch/hub/checkpoints/googlenet-1378be20.pth


  0%|          | 0.00/49.7M [00:00<?, ?B/s]

Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /Users/jameschapman/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


  0%|          | 0.00/13.6M [00:00<?, ?B/s]

Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /Users/jameschapman/.cache/torch/hub/checkpoints/resnext50_32x4d-7cdf4587.pth


  0%|          | 0.00/95.8M [00:00<?, ?B/s]

Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to /Users/jameschapman/.cache/torch/hub/checkpoints/wide_resnet50_2-95faca4d.pth


  0%|          | 0.00/132M [00:00<?, ?B/s]

Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/jameschapman/.cache/torch/hub/pytorch_vision_v0.10.0


In [5]:
# Rows = datasets, columns = network architectures.
df = pd.DataFrame.from_dict(results_dict, orient='index')

# Create the output directory first: on a fresh checkout it does not exist
# and to_pickle would fail only after the long experiment loop has finished.
os.makedirs(save_path, exist_ok=True)
df.to_pickle(os.path.join(save_path, f"results_{num_experiments}.pkl"))

print(df)

                  ResNet  ShuffleNet    AlexNet   DenseNet  GoogLeNet  \
open_sar_ship  69.208038   68.004398  80.881400  72.511848  74.486415   
fusar          85.655738   86.374269  82.896305  88.940901  85.046729   

               MobileNetV2    ResNeXt  Wide ResNet  
open_sar_ship    72.225666  74.035517    72.229141  
fusar            85.058824  88.266509    85.462036  
