In [1]:
import shlex
from tqdm.notebook import tqdm
from IPython.display import display

import torch
import torch.nn as nn
import copy
from tqdm import tqdm_notebook
from torch.utils.data import ConcatDataset
from GTTA.data.good_loaders.fast_data_loader import InfiniteDataLoader, FastDataLoader
from GTTA.utils.register import register
import pandas as pd
import numpy as np

from GTTA import config_summoner
from GTTA.kernel.main import initialize_model_dataset
from GTTA.utils.args import args_parser
from GTTA.utils.load_manager import load_atta_algorithm, load_alg
from GTTA.utils.load_manager import load_tta_algorithm
from GTTA.utils.logger import load_logger

# Emulate a command-line invocation: build the argument list that the
# project's argument parser would otherwise receive from the shell.
argv = shlex.split('--task train --config_path TTA_configs/PACS/base.yaml --gpu_idx 6 --exp_round 1')
print(argv)
# Parse the arguments and build the experiment config from them.
args = args_parser(argv)
config = config_summoner(args)
# NOTE(review): the cell output suggests load_logger replaces the built-in
# print — confirm. print(argv) above is executed before that happens.
load_logger(config)

# Instantiate the algorithm object; later cells read its datasets and model.
alg = load_atta_algorithm(config)

# Short aliases used by later cells.
# NOTE(review): presumably D_S is the source split and D_T the target split —
# confirm against the algorithm class.
D_S = alg.dataset[0]
D_T = alg.target_dataset



['--task', 'train', '--config_path', 'TTA_configs/PACS/base.yaml', '--gpu_idx', '6', '--exp_round', '1']




This logger will substitute general print function


Process Process-184:
Traceback (most recent call last):
  File "/data/shurui.gui/anaconda3/envs/torch_110/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/data/shurui.gui/anaconda3/envs/torch_110/lib/python3.8/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/data/shurui.gui/anaconda3/envs/torch_110/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
KeyboardInterrupt
  File "/data/shurui.gui/anaconda3/envs/torch_110/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/data/shurui.gui/anaconda3/envs/torch_110/lib/python3.8/multiprocessing/queues.py", line 195, in _finalize_join
    thread.join()
  File "/data/shurui.gui/anaconda3/envs/torch_110/lib/python3.8/threading.py", line 1011, in join
    self._wait_for_tstate_lock()
  File "/data/shurui.gui/anaconda3/envs/torch_110/lib/python3.8/threading

## Experiment 5: Does high/low entropy correspond to the target/source domain?

* Use the target dataset.
* Source-free AL setting, source-trained model, m=500.
* Loss: distance between the sample and the true distribution.
* Entropy: distance between the sample and the model distribution (source).

In [3]:
def softmax_entropy(x: torch.Tensor, y: torch.Tensor = None) -> torch.Tensor:
    """Per-sample softmax entropy, or symmetric cross-entropy of two logit sets.

    Args:
        x: Logits with the class dimension at dim 1. A tensor with a single
            column is treated as one binary logit and expanded to ``[x, -x]``.
        y: Optional second set of logits. It receives the same single-column
            expansion as ``x``.

    Returns:
        A tensor reduced over dim 1. With ``y is None`` this is the entropy
        ``-sum(p_x * log p_x)``; otherwise it is the symmetric cross-entropy
        ``-0.5 * sum(p_x * log p_y) - 0.5 * sum(p_y * log p_x)``.
    """
    def _expand_binary(logits: torch.Tensor) -> torch.Tensor:
        # A softmax over a single column is always 1 (log-softmax 0), which
        # would make every result 0; expand to an explicit two-class form.
        if logits.shape[1] == 1:
            return torch.cat([logits, -logits], dim=1)
        return logits

    x = _expand_binary(x)
    if y is None:
        return -(x.softmax(1) * x.log_softmax(1)).sum(1)

    # The expansion used to be skipped on this branch, so single-logit inputs
    # silently produced zeros.
    y = _expand_binary(y)
    return - 0.5 * (x.softmax(1) * y.log_softmax(1)).sum(1) - 0.5 * (y.softmax(1) * x.log_softmax(1)).sum(1)

In [23]:
# One pass over the whole target dataset to collect per-sample statistics.
# NOTE(review): sequential=True is presumably an in-order pass; later cells
# use positions in these arrays as subset indices into D_T, so the order has
# to match the dataset order — confirm against FastDataLoader.
Full_D_T_loader = FastDataLoader(D_T,
                                 weights=None,
                                 batch_size=config.train.train_bs,
                                 num_workers=config.num_workers,
                                 sequential=True)

device = config.device
# NOTE(review): if alg.model is an nn.Module, .cpu() moves alg.model itself to
# the CPU in place; only the deep copy is placed on `device`.
model = copy.deepcopy(alg.model.cpu()).to(device)
model.eval()

# Collect one chunk per batch and join once at the end; concatenating inside
# the loop re-copies everything gathered so far on every iteration.
loss_chunks, acc_chunks, entropy_chunks = [], [], []

with torch.no_grad():
    for data, target in tqdm(Full_D_T_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Per-sample loss (reduction='none'), correctness (0/1) and entropy.
        loss_chunks.append(config.metric.loss_func(output, target, reduction='none').cpu().numpy())
        acc_chunks.append((target == output.argmax(-1)).float().cpu().numpy())
        entropy_chunks.append(softmax_entropy(output).cpu().numpy())

# Leading with an empty float64 array keeps the result dtype at float64 and
# yields empty arrays (rather than an error) when the loader has no batches.
_empty = np.array([], dtype=float)
entropy = np.concatenate([_empty, *entropy_chunks])
loss = np.concatenate([_empty, *loss_chunks])
acc = np.concatenate([_empty, *acc_chunks])

  0%|          | 0/497 [00:00<?, ?it/s]

In [24]:
# Sample positions ordered from lowest to highest entropy.
argent = np.argsort(entropy)

# Mean loss of the 300 highest-entropy samples vs. samples with entropy < 1e-6.
display(loss[argent[-300:]].mean())
display(loss[entropy < 1e-6].mean())
# Mean accuracy of the 300 highest-entropy samples vs. low-entropy samples.
# NOTE(review): this threshold (1e-2) differs from the loss line above (1e-6) — confirm it is intended.
display(acc[argent[-300:]].mean())
display(acc[entropy < 1e-2].mean())
# Entropy at the boundaries of the top-2000 and bottom-500 groups.
display(entropy[argent[-2000:]].min())
display(entropy[argent[:500]].max())
# Number of samples with entropy above 0.1. This used to be a bare expression
# in the middle of the cell, so its value was never shown.
display((entropy > 0.1).sum())

# Drawing 300 without replacement from exactly 300 candidates is a permutation
# of the same set, so this repeats the first mean above.
display(loss[np.random.choice(argent[-300:], 300, replace=False)].mean())

1.9378976572553317

0.07726989215865558

0.23666666666666666

0.9407971864009379

0.45667752623558044

2.63275734546653e-09

1.9378976572553317

In [30]:
def train_model_on(dataset, subset, device, tol = 1e-3, source_trained=False,
                   lr=1e-3, momentum=0.9, max_epochs=100, steps_per_epoch=100, patience=5):
    """Train a classifier on ``dataset`` restricted by ``subset`` and return it.

    Args:
        dataset: Dataset handed to ``InfiniteDataLoader``.
        subset: Forwarded as the loader's ``subset`` argument (presumably the
            sample indices to draw batches from — confirm against
            InfiniteDataLoader).
        device: Device the model and every batch are moved to.
        tol: Training stops as soon as an epoch's mean loss falls below this.
        source_trained: If True, start from a deep copy of ``alg.model``;
            otherwise build a new encoder from ``register.models`` followed by
            a new linear classification head.
        lr: SGD learning rate.
        momentum: SGD momentum.
        max_epochs: Upper bound on the number of epochs.
        steps_per_epoch: Number of batches (optimizer steps) per epoch.
        patience: Stop after this many consecutive epochs without a new
            lowest mean loss.

    Returns:
        The trained model, still in training mode.
    """
    loader = iter(InfiniteDataLoader(dataset,
                                     weights=None,
                                     batch_size=config.train.train_bs,
                                     num_workers=config.num_workers,
                                     subset=subset))

    if source_trained:
        # NOTE(review): if alg.model is an nn.Module, .cpu() also moves
        # alg.model itself to the CPU; only the copy goes to `device`.
        model = copy.deepcopy(alg.model.cpu()).to(device)
        display('source_trained')
    else:
        encoder = register.models[config.model.name](config)
        model = nn.Sequential(encoder,
                              nn.Linear(encoder.n_outputs, config.dataset.num_classes)).to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    lowest_loss = float('inf')
    stop_count = 0  # consecutive epochs without a new lowest mean loss
    epoch_pbar = tqdm(range(max_epochs))
    for epoch in epoch_pbar:
        mean_loss = []
        for _ in range(steps_per_epoch):
            data, target = next(loader)
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = config.metric.loss_func(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            mean_loss.append(loss.item())
        mean_loss = np.mean(mean_loss)
        epoch_pbar.set_description(f'ML: {mean_loss:.6f}')

        # Early stopping: reset the counter on improvement, otherwise count up.
        stop_count += 1
        if mean_loss < lowest_loss:
            stop_count = 0
            lowest_loss = mean_loss
        if stop_count >= patience or mean_loss < tol:
            break
    return model

@torch.no_grad()
def test_on_env(model, loader, device):
    """Evaluate ``model`` over ``loader`` and return ``(mean_loss, mean_score)``.

    The summed loss and the size-weighted score of every batch are
    accumulated, then both totals are divided by ``len(loader.sampler)``.
    """
    model.eval()
    loss_total, score_total = 0, 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        batch_loss = config.metric.loss_func(logits, labels, reduction='sum')
        loss_total += batch_loss.item()
        # score_func is applied per batch, so weight it by the batch size.
        score_total += config.metric.score_func(labels, logits) * len(inputs)
    n_samples = len(loader.sampler)
    return loss_total / n_samples, score_total / n_samples

In [31]:
# Evaluation loaders over the full D_S and D_T datasets, built with identical
# settings. This rebinds Full_D_T_loader without the sequential=True flag used
# in the entropy-statistics cell.
Full_D_S_loader, Full_D_T_loader = (
    FastDataLoader(dataset,
                   weights=None,
                   batch_size=config.train.train_bs,
                   num_workers=config.num_workers)
    for dataset in (D_S, D_T)
)

In [36]:
# Baseline: train a fresh model (source_trained defaults to False) on 2000
# target samples drawn uniformly at random, i.e. without entropy-based selection.
test_idx = 0  # index into test_on_env's (loss, score) result; 0 selects the loss
source_like_model = train_model_on(
    D_T,
    np.random.choice(len(argent), 2000, replace=False),
    device,
    tol=1e-4,
)
result_df = pd.DataFrame(index=['Model'], columns=['D_S', 'D_T'], dtype=float)
for column, env_loader in (('D_T', Full_D_T_loader), ('D_S', Full_D_S_loader)):
    result_df.loc['Model', column] = test_on_env(source_like_model, env_loader, device=device)[test_idx]
display(result_df)

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,D_S,D_T
Model,0.950454,0.214905


In [33]:
# Train two fresh models: one on the 2000 lowest-entropy target samples
# ("source-like") and one on the 2000 highest-entropy ones ("target-like"),
# then evaluate each on the full target and source loaders.
test_idx = 0  # index into test_on_env's (loss, score) result; 0 selects the loss
source_like_model = train_model_on(D_T, argent[:2000], device, tol=1e-4, source_trained=False)
target_like_model = train_model_on(D_T, argent[-2000:], device, tol=1e-4, source_trained=False)
result_f_df = pd.DataFrame(index=['Model_S_like', 'Model_T_like'], columns=['D_S', 'D_T'], dtype=float)
for row, trained_model in (('Model_S_like', source_like_model), ('Model_T_like', target_like_model)):
    for column, env_loader in (('D_T', Full_D_T_loader), ('D_S', Full_D_S_loader)):
        result_f_df.loc[row, column] = test_on_env(trained_model, env_loader, device=device)[test_idx]
display(result_f_df)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,D_S,D_T
Model_S_like,0.564108,0.802188
Model_T_like,2.511737,0.34139


Unnamed: 0,D_S,D_T
Model_S_like,0.813965,0.756515
Model_T_like,0.51416,0.916278


In [32]:
# Fine-tune copies of the source-trained model on the 300 lowest-entropy
# ("source-like") and the 300 highest-entropy ("target-like") target samples,
# then evaluate each on the full target and source loaders.
test_idx = 0  # index into test_on_env's (loss, score) result; 0 selects the loss
source_like_model = train_model_on(D_T, argent[:300], device, tol=1e-4, source_trained=True)
target_like_model = train_model_on(D_T, argent[-300:], device, tol=1e-4, source_trained=True)
result_df = pd.DataFrame(index=['Model_S_like', 'Model_T_like'], columns=['D_S', 'D_T'], dtype=float)
for row, trained_model in (('Model_S_like', source_like_model), ('Model_T_like', target_like_model)):
    for column, env_loader in (('D_T', Full_D_T_loader), ('D_S', Full_D_S_loader)):
        result_df.loc[row, column] = test_on_env(trained_model, env_loader, device=device)[test_idx]
display(result_df)

'source_trained'

  0%|          | 0/100 [00:00<?, ?it/s]

'source_trained'

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,D_S,D_T
Model_S_like,0.061888,1.883793
Model_T_like,0.853899,0.772455


In [74]:
# Fine-tune copies of the source-trained model on mixed subsets:
#   S-like: 2000 lowest-entropy + 300 highest-entropy target samples
#   T-like: 300 lowest-entropy + 2000 highest-entropy target samples
# then evaluate each on the full target and source loaders.
test_idx = 0  # index into test_on_env's (loss, score) result; 0 selects the loss
source_like_model = train_model_on(
    D_T, np.concatenate([argent[:2000], argent[-300:]]), device, tol=1e-3, source_trained=True)
target_like_model = train_model_on(
    D_T, np.concatenate([argent[:300], argent[-2000:]]), device, tol=1e-3, source_trained=True)
result_df = pd.DataFrame(index=['Model_S_like', 'Model_T_like'], columns=['D_S', 'D_T'], dtype=float)
for row, trained_model in (('Model_S_like', source_like_model), ('Model_T_like', target_like_model)):
    for column, env_loader in (('D_T', Full_D_T_loader), ('D_S', Full_D_S_loader)):
        result_df.loc[row, column] = test_on_env(trained_model, env_loader, device=device)[test_idx]
display(result_df)

'source_trained'

  0%|          | 0/100 [00:00<?, ?it/s]

'source_trained'

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,D_S,D_T
Model_S_like,0.136056,0.459091
Model_T_like,0.356054,0.200455


In [52]:
# Empty float table: rows are 'Current domain', 0-3 and 'Frame AVG';
# columns are 0-3 and 'Test AVG'.
df = pd.DataFrame(
    index=['Current domain', *range(4), 'Frame AVG'],
    columns=[*range(4), 'Test AVG'],
    dtype=float,
)

In [53]:
# Write a sample value into the cell at row label 0 / column label 0
# (.loc is label-based, so this is not the first row, 'Current domain').
df.loc[0, 0] = 3.9898989898893748

In [60]:
# Show the table rounded to 4 decimals; the result is not assigned back, so
# df itself keeps the unrounded values.
df.round(4)

Unnamed: 0,0,1,2,3,Test AVG
Current domain,,,,,
0,3.9899,,,,
1,,,,,
2,,,,,
3,,,,,
Frame AVG,,,,,


Unnamed: 0,task,random_seed,exp_round,log_file,gpu_idx,ckpt_root,ckpt_dir,save_tag,other_saved,clean_save,...,atta,config_path,test_ckpt,id_test_ckpt,full_clean,log_path,tensorboard_logdir,tta,device,metric
weight_decay,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
save_gap,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
tr_ctn,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
epoch,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
ctn_epoch,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
alpha,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
stage_stones,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
pre_train,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
lr,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
mile_stones,train,110,1,default,6,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,,False,...,,/data/shurui.gui/Projects/TTA/GraphTTA/configs...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,/data/shurui.gui/Projects/TTA/GraphTTA/storage...,,cuda:6,<GTTA.utils.metric.Metric object at 0x7f813038...
