In [1]:
import os
from os.path import join
from itertools import chain
import argparse
import yaml

import pandas as pd
import numpy as np

import sklearn
import sklearn.model_selection
import sklearn.preprocessing
import sklearn.ensemble
import sklearn.manifold
import sklearn.linear_model
import sklearn.svm

import torch
import torch.utils.data

import dredda.model as model

import dredda.data as data
import dredda.test as test
import dredda.model as model
import dredda.train as train
from dredda.cloud_files import remote_files_manifest, remote_files_checksum, download_if_not_exist,local_prefix
from dredda.helpers import seed_all
# Work from the notebook's parent directory (presumably the repository root —
# confirm) so relative paths such as args.out_dir resolve there.
# The starting directory is recorded only on the first run of this cell, so
# re-running it does not keep climbing one level up the directory tree.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(join(NOTEBOOK_DIR, ".."))

  from .autonotebook import tqdm as notebook_tqdm


# Training with a dummy example

Set the training arguments and save them to the output directory

In [2]:
from argparse import Namespace

# Run configuration, held in an argparse.Namespace (mirroring a command-line
# parse) so it can be dumped to YAML alongside the training outputs.
args = Namespace(
    source_dataset_name="source_GEP",    # display name of the labelled source domain
    target_dataset_name="target_GEP",    # display name of the target domain
    out_dir="train_dir/notebook-train",  # receives the args YAML and the trainer's save_root
    n_epochs=100,
    dual_training_epoch=50,
    domain_adv_coeff=0.1,
    ddc_coeff=0.1,
    seed=41,
)

In [3]:
# Refuse to overwrite an earlier run: the output directory must not exist yet.
if not os.path.isdir(args.out_dir):
    os.makedirs(args.out_dir)
else:
    raise ValueError(f"{args.out_dir} already exists")

# Record the run configuration inside the output directory. Mode "x" fails if
# the file already exists instead of silently replacing it.
args_yaml_path = join(args.out_dir, "train--args.yaml")
with open(args_yaml_path, "x") as args_file:
    yaml.dump(vars(args), args_file)

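Generate a dummy example: random Gaussian features for a labelled source domain and an unlabelled target domain, with source labels drawn at random (so there is no real signal to learn)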

In [4]:
def load_data(n_samples=1000, n_features=1000, n_classes=4):
    """Generate a synthetic source/target domain-adaptation problem.

    Features for both domains are drawn i.i.d. from a standard normal
    distribution. Source labels are drawn uniformly at random, independently
    of the features, so the data only exercises the training pipeline and
    contains nothing learnable. Draws come from NumPy's *global* RNG. Seed it
    first (e.g. ``np.random.seed`` or, as this notebook does, ``seed_all``)
    to get reproducible data.

    Parameters
    ----------
    n_samples : int, default 1000
        Number of samples generated for each domain.
    n_features : int, default 1000
        Number of features per sample (shared by both domains).
    n_classes : int, default 4
        Number of source classes; labels are integers in ``[0, n_classes)``.

    Returns
    -------
    X_source : ndarray of shape (n_samples, n_features)
    Y_source : ndarray of shape (n_samples,)
        Integer class labels for the source domain.
    X_target : ndarray of shape (n_samples, n_features)
    Y_target : None
        The target domain is unlabelled.
    """
    # Draw order (X_source, Y_source, X_target) matches the original
    # implementation, so a given seed still produces the same data.
    X_source = np.random.randn(n_samples, n_features)
    Y_source = np.random.randint(low=0, high=n_classes, size=n_samples)
    X_target = np.random.randn(n_samples, n_features)
    Y_target = None
    return X_source, Y_source, X_target, Y_target

Split the labelled source data into training (80%) and validation (20%) sets

In [5]:
# Seed every RNG before generating data: load_data draws from NumPy's global RNG.
seed_all(args.seed)
X_source, Y_source, X_target, Y_target = load_data()
source_dataset_name = args.source_dataset_name
target_dataset_name = args.target_dataset_name

# Hold out 20% of the labelled source domain for validation.
# NOTE(review): the split uses its own fixed seed (123), independent of args.seed,
# and that seed is not recorded in the saved args YAML.
source_split = sklearn.model_selection.train_test_split(
    X_source, Y_source, test_size=0.2, random_state=123
)
X_source_train, X_source_test, Y_source_train, Y_source_test = source_split

Create model

In [6]:
# Instantiate the model with one input unit per feature column and move it to
# the GPU; `.cuda()` raises if no CUDA device is available.
net = model.FCModelDualBranchAE(n_in_features=X_source.shape[1]).cuda()

# The trainer receives two parameter groups: the `ae_encoder_t` parameters on
# their own, and the source encoder, decoder, feature extractor and the two
# classifier heads chained together.
# NOTE(review): how each group is optimised is defined in dredda.train — confirm.
target_encoder_params = net.ae_encoder_t.parameters()
other_params = chain(
    net.ae_encoder_s.parameters(),
    net.ae_decoder.parameters(),
    net.feature.parameters(),
    net.class_classifier.parameters(),
    net.domain_classifier.parameters(),
)
trainer = train.DualBranchDATrainer(
    net,
    source_dataset_name,
    target_dataset_name,
    target_encoder_params,
    other_params,
    n_epochs=args.n_epochs,
    domain_adv_coeff=args.domain_adv_coeff,
    ddc_coeff=args.ddc_coeff,
    dual_training_epoch=args.dual_training_epoch,
    save_root=args.out_dir,
)

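Fit the model. Because the dummy labels are random, the network can only memorise the training split: expect source-train accuracy to climb towards 1.0 while source-validation accuracy stays near chance (about 0.25).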

In [7]:
# Train on the labelled source split plus the target data (Y_target is None,
# i.e. unlabelled), evaluating on the held-out source split as validation.
# NOTE(review): save=True presumably writes outputs under save_root
# (args.out_dir) — confirm in dredda.train.
source_val_split = {"X_source_val": X_source_test, "Y_source_val": Y_source_test}
trainer.fit(
    X_source_train,
    Y_source_train,
    X_target,
    Y_target,
    save=True,
    **source_val_split,
)

Length of dataloaders source: 7 target: 8
Parameters: alpha = 1.0000, domain_adv_coeff = 0.1000, ddc_coeff = 0.1000, ddc_features = c_fc2
Before training:
	Accuracy on source_GEP dataset: 0.242500
	Confusion matrix:
[[194   0   0   0]
 [201   0   0   0]
 [210   0   0   0]
 [195   0   0   0]]

Epoch 1:
	Cumulative metrics:
		domain_loss_t_domain: 0.08156387380191259
		domain_loss_s_domain: 0.058403026206152786
		class_loss_s_domain: 1.39454185962677
		ddc: 4.3295491001872835e-09
		loss: 1.5345087562288557
	On source train set:
		class_accuracy: 0.2425
		class_loss: 1.394539475440979
		domain_loss: 0.058436132967472076
		domain_accuracy: 1.0
		confusion matrix:
[[194   0   0   0]
 [201   0   0   0]
 [210   0   0   0]
 [195   0   0   0]]

	On source val set:
		class_accuracy: 0.285
		class_loss: 1.3877177238464355
		domain_loss: 0.058426856994628906
		domain_accuracy: 1.0
		confusion matrix:
[[57  0  0  0]
 [55  0  0  0]
 [41  0  0  0]
 [47  0  0  0]]

	On target set:
		domain_loss: 0.081


Epoch 13:
	Cumulative metrics:
		domain_loss_t_domain: 0.0829034413610186
		domain_loss_s_domain: 0.05801101156643458
		class_loss_s_domain: 1.3734887838363647
		ddc: 3.6805074030001256e-05
		loss: 1.514440025602068
	On source train set:
		class_accuracy: 0.2425
		class_loss: 1.3699771165847778
		domain_loss: 0.057727713137865067
		domain_accuracy: 1.0
		confusion matrix:
[[194   0   0   0]
 [201   0   0   0]
 [210   0   0   0]
 [195   0   0   0]]

	On source val set:
		class_accuracy: 0.285
		class_loss: 1.3878774642944336
		domain_loss: 0.05720660835504532
		domain_accuracy: 1.0
		confusion matrix:
[[57  0  0  0]
 [55  0  0  0]
 [41  0  0  0]
 [47  0  0  0]]

	On target set:
		domain_loss: 0.08338465541601181
		domain_accuracy: 0.0

Epoch 14:
	Cumulative metrics:
		domain_loss_t_domain: 0.08378557647977557
		domain_loss_s_domain: 0.05745434165000916
		class_loss_s_domain: 1.3659070049013409
		ddc: 4.7324290374360444e-05
		loss: 1.5071942635944913
	On source train set:
		class_accura

		domain_loss: 0.06827475875616074
		domain_accuracy: 0.701

Epoch 26:
	Cumulative metrics:
		domain_loss_t_domain: 0.06660089833395823
		domain_loss_s_domain: 0.06856545635632107
		class_loss_s_domain: 0.5273972494261605
		ddc: 0.10540093247379576
		loss: 0.767964541912079
	On source train set:
		class_accuracy: 0.9725
		class_loss: 0.4963904619216919
		domain_loss: 0.06903361529111862
		domain_accuracy: 0.355
		confusion matrix:
[[177   1   7   9]
 [  0 201   0   0]
 [  0   0 210   0]
 [  0   0   5 190]]

	On source val set:
		class_accuracy: 0.255
		class_loss: 1.945643663406372
		domain_loss: 0.07490503042936325
		domain_accuracy: 0.245
		confusion matrix:
[[18  9 13 17]
 [21 10  7 17]
 [14  7  4 16]
 [14 12  2 19]]

	On target set:
		domain_loss: 0.06615819782018661
		domain_accuracy: 0.718

Epoch 27:
	Cumulative metrics:
		domain_loss_t_domain: 0.06685993586267744
		domain_loss_s_domain: 0.06786582895687648
		class_loss_s_domain: 0.4843233738626753
		ddc: 0.0874116810304778
		los


Epoch 37:
	Cumulative metrics:
		domain_loss_t_domain: 0.0691207298210689
		domain_loss_s_domain: 0.0680694682257516
		class_loss_s_domain: 0.19293057918548587
		ddc: 0.266165880220277
		loss: 0.5962866757597243
	On source train set:
		class_accuracy: 0.99
		class_loss: 0.19308429956436157
		domain_loss: 0.06937529146671295
		domain_accuracy: 0.51875
		confusion matrix:
[[189   0   3   2]
 [  0 201   0   0]
 [  0   0 210   0]
 [  0   0   3 192]]

	On source val set:
		class_accuracy: 0.27
		class_loss: 2.938016653060913
		domain_loss: 0.0792059525847435
		domain_accuracy: 0.255
		confusion matrix:
[[23  6 14 14]
 [25  8  7 15]
 [16  6  5 14]
 [15 11  3 18]]

	On target set:
		domain_loss: 0.06770656257867813
		domain_accuracy: 0.532

Epoch 38:
	Cumulative metrics:
		domain_loss_t_domain: 0.06783657925469536
		domain_loss_s_domain: 0.0687754009451185
		class_loss_s_domain: 0.18954987611089433
		ddc: 0.269077114548002
		loss: 0.5952389793736594
	On source train set:
		class_accuracy: 0.


Epoch 49:
	Cumulative metrics:
		domain_loss_t_domain: 0.06701391765049525
		domain_loss_s_domain: 0.068208532673972
		class_loss_s_domain: 0.12968379046235767
		ddc: 0.47999552062579565
		loss: 0.7449017763137817
	On source train set:
		class_accuracy: 0.99375
		class_loss: 0.12725670635700226
		domain_loss: 0.06905113905668259
		domain_accuracy: 0.52625
		confusion matrix:
[[191   0   1   2]
 [  0 201   0   0]
 [  0   0 210   0]
 [  0   0   2 193]]

	On source val set:
		class_accuracy: 0.295
		class_loss: 3.2684781551361084
		domain_loss: 0.07635625451803207
		domain_accuracy: 0.24
		confusion matrix:
[[22  6 11 18]
 [23  8  7 17]
 [16  6  5 14]
 [11 11  1 24]]

	On target set:
		domain_loss: 0.06602729111909866
		domain_accuracy: 0.627

Epoch 50:
	Cumulative metrics:
		domain_loss_t_domain: 0.06681192346981593
		domain_loss_s_domain: 0.06865784696170261
		class_loss_s_domain: 0.1300419494509697
		ddc: 0.3028193141732897
		loss: 0.5683310372488839
	On source train set:
		class_accu


Epoch 61:
	Cumulative metrics:
		domain_loss_t_domain: 0.06562812498637609
		domain_loss_s_domain: 0.06922444701194763
		class_loss_s_domain: 0.10516545815127236
		ddc: 0.36297729089856146
		loss: 0.6029953360557556
	On source train set:
		class_accuracy: 0.99625
		class_loss: 0.09925999492406845
		domain_loss: 0.06465563178062439
		domain_accuracy: 0.7525
		confusion matrix:
[[193   0   1   0]
 [  0 201   0   0]
 [  0   0 210   0]
 [  0   0   2 193]]

	On source val set:
		class_accuracy: 0.265
		class_loss: 3.3354272842407227
		domain_loss: 0.071814626455307
		domain_accuracy: 0.48
		confusion matrix:
[[22  6 13 16]
 [24  7  7 17]
 [16  6  5 14]
 [13 11  4 19]]

	On target set:
		domain_loss: 0.06972222030162811
		domain_accuracy: 0.417

Epoch 62:
	Cumulative metrics:
		domain_loss_t_domain: 0.0703314551285335
		domain_loss_s_domain: 0.06426148159163339
		class_loss_s_domain: 0.09902945160865784
		ddc: 0.2363294277872358
		loss: 0.46995183399745394
	On source train set:
		class_accu


Epoch 73:
	Cumulative metrics:
		domain_loss_t_domain: 0.06977042300360543
		domain_loss_s_domain: 0.06434024912970407
		class_loss_s_domain: 0.09495015974555697
		ddc: 0.5046770351273673
		loss: 0.7337379114968436
	On source train set:
		class_accuracy: 0.99625
		class_loss: 0.09514123946428299
		domain_loss: 0.0659060999751091
		domain_accuracy: 0.7525
		confusion matrix:
[[193   0   1   0]
 [  0 201   0   0]
 [  0   0 210   0]
 [  0   0   2 193]]

	On source val set:
		class_accuracy: 0.275
		class_loss: 3.351072072982788
		domain_loss: 0.07336049526929855
		domain_accuracy: 0.36
		confusion matrix:
[[23  5 13 16]
 [23  7  7 18]
 [18  5  5 13]
 [12 11  4 20]]

	On target set:
		domain_loss: 0.06810002028942108
		domain_accuracy: 0.519

Epoch 74:
	Cumulative metrics:
		domain_loss_t_domain: 0.06662303805351256
		domain_loss_s_domain: 0.06770761779376439
		class_loss_s_domain: 0.09725732249873026
		ddc: 0.36333371241177836
		loss: 0.5949217017207826
	On source train set:
		class_accu


Epoch 85:
	Cumulative metrics:
		domain_loss_t_domain: 0.06566302946635656
		domain_loss_s_domain: 0.06791470817157201
		class_loss_s_domain: 0.06534940696188382
		ddc: 0.6694291012627738
		loss: 0.8683562534196037
	On source train set:
		class_accuracy: 0.9975
		class_loss: 0.07110226899385452
		domain_loss: 0.06759844720363617
		domain_accuracy: 0.72625
		confusion matrix:
[[193   0   1   0]
 [  0 201   0   0]
 [  0   0 210   0]
 [  0   0   1 194]]

	On source val set:
		class_accuracy: 0.285
		class_loss: 3.597501516342163
		domain_loss: 0.07599630951881409
		domain_accuracy: 0.285
		confusion matrix:
[[24  6 11 16]
 [23  8  7 17]
 [17  6  5 13]
 [13 12  2 20]]

	On target set:
		domain_loss: 0.06636165082454681
		domain_accuracy: 0.653

Epoch 86:
	Cumulative metrics:
		domain_loss_t_domain: 0.06828998753002712
		domain_loss_s_domain: 0.06497895121574403
		class_loss_s_domain: 0.07084875819938524
		ddc: 0.27031293229333
		loss: 0.47443064408642904
	On source train set:
		class_accu


Epoch 97:
	Cumulative metrics:
		domain_loss_t_domain: 0.06546641758510045
		domain_loss_s_domain: 0.06747242893491473
		class_loss_s_domain: 0.08464434423616955
		ddc: 0.15431557212557112
		loss: 0.37189876607486183
	On source train set:
		class_accuracy: 0.99875
		class_loss: 0.08150321245193481
		domain_loss: 0.0654088482260704
		domain_accuracy: 0.75125
		confusion matrix:
[[194   0   0   0]
 [  0 201   0   0]
 [  0   0 210   0]
 [  0   0   1 194]]

	On source val set:
		class_accuracy: 0.265
		class_loss: 3.244053840637207
		domain_loss: 0.07307589799165726
		domain_accuracy: 0.36
		confusion matrix:
[[21  5 13 18]
 [23  7  7 18]
 [17  5  5 14]
 [13 10  4 20]]

	On target set:
		domain_loss: 0.0672459825873375
		domain_accuracy: 0.61

Epoch 98:
	Cumulative metrics:
		domain_loss_t_domain: 0.06820615444864546
		domain_loss_s_domain: 0.06481829711369107
		class_loss_s_domain: 0.08165366521903446
		ddc: 0.26475440263748173
		loss: 0.4794325296367917
	On source train set:
		class_acc

To train the model on the full dataset, please use `python -m dredda train`