# Setup

In [28]:
%reload_ext autoreload

In [3]:
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging

logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    def filter(self, record):
        return "check_types" not in record.getMessage()


logger.addFilter(CheckTypesFilter())

In [8]:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.set_printoptions(precision=3)
import scipy.stats
import einops


from functools import partial
from collections import namedtuple
import itertools
from itertools import repeat
from time import time

import chex
import jax
import jax.random as jr
import jax.numpy as jnp
from jax import vmap, grad, jit, lax
from jax import numpy as jnp
import jax.scipy as jsp

from flax.core import freeze, unfreeze
from flax import linen as nn
import flax

import jaxopt
import optax


from PIL import Image


#jax.config.update("jax_enable_x64", False)



In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds


import torch
from torch.utils.data import TensorDataset
import torchvision.transforms as T

In [2]:
import os 
cpu_count = os.cpu_count()
print(cpu_count)

# Run jax on multiple CPU cores
# https://github.com/google/jax/issues/5506
# https://stackoverflow.com/questions/72328521/jax-pmap-with-multi-core-cpu
import os 
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=90'

import jax
print(jax.devices())

96
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]


# Import library code

In [2]:
# https://github.com/probml/probml-utils
import probml_utils
from probml_utils.mlp_flax import MLPNetwork, NeuralNetClassifier
print(MLPNetwork)

2022-11-10 19:49:37.960510: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2022-11-10 19:49:37.993371: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-10 19:49:38.787683: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2022-11-10 19:49:38.787829: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib


<class 'probml_utils.mlp_flax.MLPNetwork'>


In [3]:

import tta
from tta.utils import *
print(Dataset)

from tta.datasets import *
print(MultipleDomainDataset)

from tta.datasets.chexpert import *
print(MultipleDomainCheXpert)


<class 'tta.utils.Dataset'>
<class 'tta.datasets.MultipleDomainDataset'>
<class 'tta.datasets.chexpert.MultipleDomainCheXpert'>


# Load pre-computed data matrix

In [4]:

from pathlib import Path
root = '/home/kpmurphy/data/CheXpert'
root = Path(root)

data = np.load(root / 'data_matrix.npz', allow_pickle=True)
print(data.files)
print(data['X'].shape)
print(data['YZ'].shape)
print(data['columns'])

['X', 'YZ', 'columns']
(139907, 1376)
(139907, 4)
['PNEUMONIA' 'EFFUSION' 'GENDER' 'AGE_QUANTIZED']


# Fit logistic regression with sklearn

In [3]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import sklearn
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.pipeline import make_pipeline, Pipeline

In [28]:
X = data['X']
ndx = np.where(data['columns'] == 'EFFUSION')[0][0]
Y = np.array(data['YZ'][:,ndx], dtype=int)

from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)


N_train  = X_train.shape[0]
N_test  = X_test.shape[0]
print([N_train, N_test])

classifier = Pipeline([
        ('standardscaler', StandardScaler()),
        #('poly', PolynomialFeatures(degree=2)), 
        ('logreg', LogisticRegression(random_state=0, max_iter=500, C=10, solver='sag', multi_class='multinomial'))
])

classifier = LogisticRegression(random_state=0, max_iter=500, C=10, solver='sag', multi_class='multinomial')


[111925, 27982]


In [29]:


#N = 100
N  = N_train
XX = X_train[:N]
YY = Y_train[:N]




In [30]:
%%time
classifier.fit(XX, YY)

In [27]:

probs = classifier.predict_proba(X_test)

y_pred = jnp.argmax(probs, axis=1)
y_pred2 = classifier.predict(X_test)
assert np.allclose(y_pred, y_pred2)

y_true = Y_test
acc = sklearn.metrics.accuracy_score(y_true, y_pred)
print(acc)



0.6214352083482239


# Fit logistic regression with flax

In [5]:
X = data['X']
ndx = np.where(data['columns'] == 'EFFUSION')[0][0]
Y = np.array(data['YZ'][:,ndx], dtype=int)

from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

N_train  = X_train.shape[0]
N_test  = X_test.shape[0]
print([N_train, N_test])





[111925, 27982]


In [14]:
#N = 10000 # use subset of data
N  = N_train # use all data

XX = X_train[:N]
nclasses = 2
YY = Y_train[:N]


In [35]:
nhidden = () + (nclasses,) # set nhidden() to get logistic regression
network = MLPNetwork(nhidden)
key = jr.PRNGKey(0)
opt = optax.adamw(1e-3) #'adam+warmup'
mlp = NeuralNetClassifier(network, key, nclasses, l2reg=0, standardize=True,
        batch_size=512, num_epochs=10, print_every=1, optimizer=opt)  


In [36]:
%%time

mlp.fit(XX, YY)

fit optax
epoch 0, train loss 0.686, train accuracy 0.591
epoch 1, train loss 0.671, train accuracy 0.604
epoch 2, train loss 0.668, train accuracy 0.606
epoch 3, train loss 0.668, train accuracy 0.607
epoch 4, train loss 0.667, train accuracy 0.609
epoch 5, train loss 0.666, train accuracy 0.609
epoch 6, train loss 0.667, train accuracy 0.608
epoch 7, train loss 0.666, train accuracy 0.608
epoch 8, train loss 0.666, train accuracy 0.609
epoch 9, train loss 0.666, train accuracy 0.609
CPU times: user 2min 2s, sys: 8.82 s, total: 2min 11s
Wall time: 1min 49s


In [37]:

probs = mlp.predict(X_train)
y_pred = jnp.argmax(probs, axis=1)
acc = jnp.mean(Y_train == y_pred)
print('train accuracy', acc)

probs = mlp.predict(X_test)
y_pred = jnp.argmax(probs, axis=1)
acc = jnp.mean(Y_test == y_pred)
print('test accuracy', acc)

train accuracy 0.6144293
test accuracy 0.60042167


# Make shifted datasets for each domain


In [None]:
root = Path("/home/kpmurphy/data/CheXpert")
dataset_y_column = "EFFUSION"
dataset_z_column = "GENDER"
dataset_use_embedding = True
train_domains_set = [9]
target_domain_count = 512

import random
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
key = jax.random.PRNGKey(seed)
generator = torch.Generator().manual_seed(seed)


dataset = MultipleDomainCheXpert(root, generator, dataset_y_column, dataset_z_column, dataset_use_embedding, train_domains_set, target_domain_count)