In [1]:
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt
import torchvision
import torch
import os
import random
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, roc_curve, plot_roc_curve, auc
from sklearn.model_selection import train_test_split
from graspologic.cluster import GaussianCluster as GMM
from collections import defaultdict
from proglearn.forest import UncertaintyForest
from sklearn.tree import DecisionTreeClassifier

In [2]:
#kate's script to get auc/95
%run -i evaluate.py 
acorn = 1234
torch.manual_seed(acorn)
np.random.seed(acorn)

torch.cuda.is_available()

if torch.cuda.is_available():  
    dev = "cuda:0" 
else:  
    dev = "cpu"  
    
device = torch.device(dev)  

n_iter = 10
seeds = np.random.randint(10000, size=n_iter)

In [3]:
# Process a CheXpert-style label table: keep frontal AP studies only, drop rows
# that are uncertain in the classes we care about, and fill the remaining gaps.
def process_data(df):
    """Filter and clean a CheXpert-style label DataFrame.

    Steps, in order:
      1. keep rows where 'Frontal/Lateral' == 'Frontal' and 'AP/PA' == 'AP';
      2. drop rows carrying an uncertain label (-1) in any of the five
         category columns (missing values count as 0 for this check);
      3. drop rows whose 'No Finding' + category labels sum to 0 or less
         (missing values are skipped by the sum);
      4. keep only 'Path', 'No Finding' and the category columns, replacing
         remaining NaN and -1 values with 0.

    Rows are selected with boolean masks / label-based ``.loc``, so the result
    is correct for any index, not only a default ``RangeIndex``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'Path', 'Frontal/Lateral', 'AP/PA', 'No Finding' and the
        five category columns listed in ``category_names``.

    Returns
    -------
    pandas.DataFrame
        The filtered rows (original index labels preserved) restricted to the
        wanted columns. The input frame is not modified.
    """
    print('starting size %s' %len(df))

    category_names = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
    wanted_cols = ["Path", 'No Finding'] + category_names

    # only use frontal/AP data
    is_frontal_ap = (df['Frontal/Lateral'] == 'Frontal') & (df['AP/PA'] == 'AP')
    data = df.loc[is_frontal_ap, wanted_cols]

    # filter out -1 (uncertain labels) in the classes we care about;
    # empty values in those columns are treated as 0 for this test
    is_certain = (data[category_names].fillna(0) != -1).all(axis=1)
    data = data.loc[is_certain]

    # filter out rows with no label values ('Path' is column 0 and is skipped)
    has_label = data.iloc[:, 1:].sum(axis=1) > 0
    data = data.loc[has_label]

    # fill all NA and uncertainty as 0
    data = data.fillna(0)
    data = data.replace(-1, 0)

    print("final size %s" %len(data))
    return data



In [4]:
# CheXpert label tables: valid.csv becomes the test frame, train.csv the
# training pool. Both go through process_data right after loading.
data_root = '/home/weiwya/teamdrive_bak/weiwei_temp_data'

test_df = process_data(
    pd.read_csv('{}/CheXpert-v1.0-small/valid.csv'.format(data_root))
)
train_full = process_data(
    pd.read_csv('{}/CheXpert-v1.0-small/train.csv'.format(data_root))
)

train_full.head()

starting size 234
final size 132
starting size 223414
final size 92771


Unnamed: 0,Path,No Finding,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion
0,CheXpert-v1.0-small/train/patient00001/study1/...,1.0,0.0,0.0,0.0,0.0,0.0
4,CheXpert-v1.0-small/train/patient00003/study1/...,0.0,0.0,0.0,0.0,1.0,0.0
11,CheXpert-v1.0-small/train/patient00006/study1/...,1.0,0.0,0.0,0.0,0.0,0.0
12,CheXpert-v1.0-small/train/patient00007/study1/...,0.0,1.0,1.0,0.0,0.0,0.0
13,CheXpert-v1.0-small/train/patient00007/study2/...,0.0,1.0,0.0,0.0,0.0,0.0


In [5]:
# CheXphoto validation labels, filtered the same way as the CheXpert tables.
data_root_photo = '/home/weiwya/teamdrive_bak/weiwei_temp_data/CheXphoto/'

photo_test_df = process_data(
    pd.read_csv('{}/CheXphoto-v1.0/valid.csv'.format(data_root_photo))
)

photo_test_df.head()

starting size 702
final size 396


Unnamed: 0,Path,No Finding,Atelectasis,Cardiomegaly,Consolidation,Edema,Pleural Effusion
0,CheXphoto-v1.0/valid/synthetic/digital/patient...,0.0,0.0,1.0,0.0,0.0,0.0
3,CheXphoto-v1.0/valid/synthetic/digital/patient...,0.0,0.0,0.0,0.0,1.0,0.0
4,CheXphoto-v1.0/valid/synthetic/digital/patient...,1.0,0.0,0.0,0.0,0.0,0.0
5,CheXphoto-v1.0/valid/synthetic/digital/patient...,0.0,1.0,0.0,0.0,0.0,1.0
6,CheXphoto-v1.0/valid/synthetic/digital/patient...,0.0,1.0,1.0,0.0,0.0,0.0


In [6]:
class ChestXRayDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(transformed_image, label_row)`` pairs.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs a 'Path' column (image path relative to ``data_root``) and one
        column per entry of ``class_names``.
    transform : callable
        Applied to the RGB PIL image before it is returned.
    data_root : str
        Directory that the 'Path' values are joined onto.
    """

    def __init__(self, df, transform, data_root):
        #TODO::put something here that preserves aspect ratio
        # Label columns, in the order used for every label vector.
        self.class_names = ['No Finding', 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
        self.image_dir = data_root
        self.transform = transform
        self.total = len(df)
        self.image_names = df['Path'].to_list()
        # (n_samples, n_classes) array aligned with image_names.
        self.labels = df[self.class_names].to_numpy()

    def __len__(self):
        """Number of samples (rows of the DataFrame passed to __init__)."""
        return self.total

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, transform it and return it with its labels."""
        image_path = os.path.join(self.image_dir, self.image_names[idx])
        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns an independent image, so it is
        # still usable after the source file is closed.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert('RGB')
        image = self.transform(rgb_image)
        label = self.labels[idx]
        return image, label



In [7]:
# Image preprocessing configuration
# image_size = (320, 320)   # alternative input size, currently disabled
resnet_mean = [0.485, 0.456, 0.406]
resnet_std = [0.229, 0.224, 0.225]
image_size = (224, 224)

# Training pipeline: resize to the size the model expects, augment with a
# random horizontal flip, convert to a tensor, then normalize with the
# statistics the pretrained ResNet was trained on.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=image_size),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=resnet_mean, std=resnet_std),
])

# Test/validation pipeline: same as above but without data augmentation.
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=image_size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=resnet_mean, std=resnet_std),
])

In [8]:
def make_fine_labels(features, org_labels, fine_clfs,  n_fine_classes):
    """Expand original multi-hot labels into multi-hot fine (cluster) labels.

    For every original class ``idx`` and every sample that is positive for it,
    the sample is assigned to the cluster predicted by ``fine_clfs[idx]``.
    Cluster ids are offset by the total ``n_components_`` of the preceding
    classifiers so that each class owns a contiguous column range.

    NOTE: a function with the same name is defined again in a later cell
    (without the debug print); whichever cell runs last is the one in effect.

    Returns an array of shape ``(len(features), n_fine_classes)``.
    """
    n_samples = len(features)
    fine_labels = np.zeros((n_samples, n_fine_classes))
    print(n_samples, n_fine_classes)

    offset = 0
    for class_idx, clf in enumerate(fine_clfs):
        positive_rows = [
            row for row, flag in enumerate(org_labels[:, class_idx]) if flag == 1.
        ]
        for row in positive_rows:
            # predict() expects a 2-D input, hence the (1, n_features) reshape
            cluster = clf.predict(features[row].reshape(1, -1)) + offset
            fine_labels[row, cluster] = 1.
        offset += clf.n_components_

    return fine_labels

def make_coarse_labels(fine_labels, coarse_to_fine):
    """Collapse multi-hot fine labels into multi-hot coarse labels.

    ``coarse_to_fine`` maps a coarse label id to the list of fine label ids it
    contains. A sample gets coarse label ``c`` when any fine label belonging
    to ``c`` is set to 1.

    Returns an array of shape ``(len(fine_labels), len(coarse_to_fine))``.
    """
    n_samples = len(fine_labels)
    n_labels = len(coarse_to_fine)
    labels = np.zeros((n_samples, n_labels))

    # invert the mapping: fine id -> every coarse id that lists it
    fine_to_coarse = defaultdict(list)
    for coarse_id, fine_ids in coarse_to_fine.items():
        for fine_id in fine_ids:
            fine_to_coarse[fine_id].append(coarse_id)

    # switch on the coarse labels of every active fine label
    for row, fine_row in enumerate(fine_labels):
        active_cols = [col for col, flag in enumerate(fine_row) if flag == 1.]
        for col in active_cols:
            for coarse_id in fine_to_coarse[col]:
                labels[row, coarse_id] = 1.0

    return labels
    
#extract features from image tensors
def extract_features(feature_extractor, tensors, batch = 72):
    """Run ``feature_extractor`` over ``tensors`` in chunks and stack the results.

    Parameters
    ----------
    feature_extractor : callable
        Maps a batch tensor on ``device`` to an output tensor whose first two
        axes are (samples, features).
    tensors : torch.Tensor
        Inputs, sliced along the first axis.
    batch : int
        Number of samples sent to ``device`` per forward pass.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(tensors), n_features)``; its shape is printed
        before returning.

    Raises
    ------
    ValueError
        If a batch output cannot be viewed as ``(n_samples, n_features)``
        (i.e. the trailing axes are not all of size 1), or if ``tensors`` is
        empty (nothing to stack).
    """
    total = len(tensors)
    res = []
    # No gradients are needed for feature extraction; skipping the autograd
    # graph keeps memory usage down.
    with torch.no_grad():
        for start in range(0, total, batch):
            tensor_gpu = tensors[start: start + batch].to(device)
            outputs = feature_extractor(tensor_gpu).cpu().detach().numpy()
            n_samples = outputs.shape[0]
            n_features = outputs.shape[1]
            # reshape() instead of the in-place ndarray.resize(): the array
            # is backed by tensor memory, and reshape validates the size
            # rather than relying on in-place metadata rewriting.
            res.append(outputs.reshape(n_samples, n_features))
    res = np.vstack(res)
    print(res.shape)
    return res
#run a feature extractor over a whole data loader, collecting features and labels
def extract_fetures_targets(feature_extractor, dl):
    """Collect extractor outputs and labels for every batch of ``dl``.

    Each batch of images is moved to ``device`` and passed through
    ``feature_extractor``; outputs are brought back to the CPU as numpy
    arrays. After stacking, the features are reshaped to
    ``(len(dl.dataset), n_features)`` where ``n_features`` is the size of the
    second output axis. The CUDA cache is emptied before returning.

    Returns ``(features, targets)``, both stacked along the first axis in the
    order the loader produced them.
    """
    feature_chunks, target_chunks = [], []
    for images, labels in dl:
        batch_out = feature_extractor(images.to(device))
        feature_chunks.append(batch_out.cpu().detach().numpy())
        target_chunks.append(labels)

    targets = np.vstack(target_chunks)
    stacked = np.vstack(feature_chunks)
    features = stacked.reshape(len(dl.dataset), stacked.shape[1])
    torch.cuda.empty_cache()
    return features, targets


#multi-label classifier on a ResNet backbone
#(the class is called Resnet50Base, but the default backbone built below is resnet18)
class Resnet50Base(torch.nn.Module):
    """ResNet backbone with a dropout + linear head, followed by a sigmoid.

    Parameters
    ----------
    n_classes : int
        Number of outputs of the new linear head.
    name : str
        Stored as ``self.name``.
    starter_model : optional
        When given, its ``base_model`` is reused as the backbone. The object
        is shared rather than copied, so its ``fc`` head is replaced on the
        starter model as well. When ``None``, a pretrained torchvision
        ``resnet18`` is created.
    """

    def __init__(self, n_classes, name, starter_model=None):
        super().__init__()
        if starter_model is not None:
            backbone = starter_model.base_model
        else:
            # no starter model: start from the pretrained torchvision network
            backbone = torchvision.models.resnet18(pretrained=True)

        # replace the classification head (512 input features)
        backbone.fc = torch.nn.Sequential(
            torch.nn.Dropout(p=0.25),
            torch.nn.Linear(in_features=512, out_features=n_classes),
        )
        self.base_model = backbone

        self.sigm = torch.nn.Sigmoid()
        self.name = name

    def forward(self, x):
        """Return sigmoid-activated outputs of the backbone for ``x``."""
        logits = self.base_model(x)
        return self.sigm(logits)
    


def get_feature_extractor(model):
    """Build a frozen feature extractor from ``model.base_model``.

    All direct children of ``model.base_model`` except the last one are
    chained into a ``torch.nn.Sequential``. The child modules are shared with
    ``model`` (not copied), so freezing their parameters here also freezes
    them in ``model``. The extractor is moved to ``device`` before returning.
    """
    layers = list(model.base_model.children())
    feature_extractor = torch.nn.Sequential(*layers[:-1])
    # freeze every parameter of the extractor
    feature_extractor.requires_grad_(False)
    feature_extractor.to(device)
    return feature_extractor


In [9]:
def eval_auc_kate(targets, predicts, class_names, alpha= 0.95):
    """Alias for ``eval_auc``: forwards all four arguments positionally and
    returns its result unchanged.

    NOTE(review): ``eval_auc`` is not defined in this notebook; it is
    presumably brought into scope by the ``%run -i evaluate.py`` cell --
    confirm.
    """
    return eval_auc(targets, predicts, class_names, alpha)
       
def get_prediction(model,  dl):
    """Run ``model`` over every batch of ``dl`` and collect outputs and labels.

    The model is switched to eval mode (and left in it). Inference runs under
    ``torch.no_grad()`` so no autograd graph is kept for the outputs.

    Parameters
    ----------
    model : torch.nn.Module
        Called on each image batch after the batch is moved to ``device``.
    dl : iterable
        Yields ``(images, labels)`` batches.

    Returns
    -------
    (predicts, targets) : tuple of numpy.ndarray
        Model outputs and the corresponding labels, stacked along the first
        axis in the order the loader produced them.
    """
    model.eval()
    predicts = []
    targets = []

    with torch.no_grad():
        for images, labels in dl:
            # Only the images are needed on the device; labels are returned
            # as-is and never used in a computation here.
            outputs = model(images.to(device))
            predicts.append(outputs.cpu().detach().numpy())
            targets.append(labels)

    predicts = np.vstack(predicts)
    targets = np.vstack(targets)
    return predicts, targets

def eval_model(model, dl, alpha =0.95, verbose=True):
    """Score ``model`` on ``dl`` and return the metrics produced by ``eval_auc``.

    Predictions and targets come from ``get_prediction``; they are passed to
    ``eval_auc`` together with ``dl.dataset.class_names`` and ``alpha``. With
    ``verbose`` set, every key/value pair of the result is printed.
    """
    predicts, targets = get_prediction(model, dl)
    class_names = dl.dataset.class_names
    aucs = eval_auc(targets, predicts, class_names, alpha=alpha)
    if not verbose:
        return aucs
    for key, value in aucs.items():
        print(key, value)
    return aucs

In [10]:
def gen_fine_clf(features, labels, n_cluster_min=3, n_cluster_max=5, return_cond_mean=False):
    """Fit one clustering model per original class to define fine labels.

    For each column ``i`` of ``labels``, a ``GMM`` is fitted on the features of
    the samples positive for that class. Its cluster ids are offset by the
    total ``n_components_`` of the preceding models, giving globally unique
    fine label ids.

    Parameters
    ----------
    features : numpy.ndarray, shape (n_samples, n_features)
    labels : numpy.ndarray, shape (n_samples, n_classes)
        Multi-hot labels; a sample belongs to class ``i`` when the entry is 1.
    n_cluster_min, n_cluster_max : int
        Passed to ``GMM`` as ``min_components`` / ``max_components``.
    return_cond_mean : bool
        When True, also return the mean feature vector of every fine cluster.

    Returns
    -------
    fine_clfs : list
        One fitted model per original class.
    fine_to_org : dict
        Fine label id -> original class index, for every fine id that was
        assigned to at least one sample.
    curr : int
        Total number of components over all fitted models.
    conditional_means : numpy.ndarray (only if ``return_cond_mean``)
        One row per entry of ``fine_to_org``, in ascending fine-id order.

    NOTE(review): a component that receives no sample is counted in ``curr``
    but has no entry in ``fine_to_org`` and no row in ``conditional_means``;
    callers that identify a row index with a fine id rely on that not
    happening.
    """
    fine_clfs = []
    fine_to_org = {}
    conditional_means = []
    curr = 0

    for i in range(labels.shape[1]):
        # features of the samples that are positive for class i
        xx = features[labels[:, i] == 1.0]

        clf = GMM(min_components=n_cluster_min, max_components=n_cluster_max, reg_covar=1e-3).fit(xx)
        pp = clf.predict(xx) + curr
        curr += clf.n_components_

        unique_y = np.unique(pp)
        for y in unique_y:
            fine_to_org[y] = i

        fine_clfs.append(clf)
        if return_cond_mean:
            # pp is aligned with xx (the class subset), so the cluster means
            # must be taken over xx; indexing the full `features` array with
            # positions relative to xx would average unrelated samples.
            means = np.array([xx[pp == c].mean(axis=0) for c in unique_y])
            conditional_means.append(means)

    if not return_cond_mean:
        return fine_clfs, fine_to_org, curr

    conditional_means = np.vstack(conditional_means)
    return fine_clfs, fine_to_org, curr, conditional_means

    
#first cluster within each label to get fine labels,
#then cluster the per-cluster means to generate coarse labels
def gen_coarse_fine_clfs(features, labels, n_cluster_min=3, n_cluster_max=5,
                         n_coarse_min=7, n_coarse_max=12):
    """Build the two-level (coarse -> fine) label hierarchy.

    Fine clusters come from ``gen_fine_clf``; a second ``GMM`` is then fitted
    on the mean feature vector of each fine cluster, and fine clusters that
    fall into the same coarse component are grouped together.

    Parameters
    ----------
    features, labels, n_cluster_min, n_cluster_max
        Forwarded to ``gen_fine_clf``.
    n_coarse_min, n_coarse_max : int
        Requested range for the number of coarse components (defaults keep
        the previous hard-coded 7 and 12). Both are clamped to the number of
        fine-cluster means, because a mixture cannot be fitted with more
        components than samples.

    Returns
    -------
    fine_clf : list
        Per-class fine clustering models.
    fine_to_org : dict
        Fine label id -> original class index.
    coarse_clf
        The fitted coarse clustering model.
    coarse_to_fine : defaultdict(list)
        Coarse label id -> list of row indices of the fine-cluster means
        assigned to it. NOTE(review): these row indices equal fine label ids
        only if every fine component received at least one sample -- confirm.
    """
    fine_clf, fine_to_org, _, conditional_means = gen_fine_clf(
        features, labels, n_cluster_min, n_cluster_max, return_cond_mean=True)

    # never ask for more coarse components than there are points to cluster
    n_means = len(conditional_means)
    max_components = min(n_coarse_max, n_means)
    min_components = min(n_coarse_min, max_components)

    coarse_clf = GMM(min_components=min_components, max_components=max_components,
                     reg_covar=1e-3, tol=1e-5)
    coarse_clf.fit(conditional_means)

    coarse_to_fine = defaultdict(list)
    for i, p in enumerate(coarse_clf.predict(conditional_means)):
        coarse_to_fine[p].append(i)

    return fine_clf, fine_to_org, coarse_clf, coarse_to_fine
    
    

In [11]:
def train_model(epochs, train_model, train_loss_fn, train_optimizer, dl_train, dl_valid, eval_fn):
    """Train a model and checkpoint the weights with the best validation AUC.

    After every epoch ``eval_fn(train_model, dl_valid, verbose=False)`` is
    called. When its ``['Average']['AUC']`` beats the best value so far (or on
    the first epoch), the model's ``state_dict`` is saved to
    ``'CheXpert_<name>_resnet50'`` in the working directory.

    Parameters
    ----------
    epochs : int
        Number of passes over ``dl_train``.
    train_model : torch.nn.Module
        Model to optimise; must expose a ``name`` attribute.
    train_loss_fn : callable
        Called as ``train_loss_fn(outputs, float_targets)``.
    train_optimizer : torch.optim.Optimizer
    dl_train, dl_valid : iterable of (images, labels) batches
    eval_fn : callable
        Returns a nested dict containing ``['Average']['AUC']``.

    Returns
    -------
    str
        The checkpoint file name.
    """
    checkpoint_path = 'CheXpert_%s_resnet50' %(train_model.name)
    best_auc_dic = None

    for e in range(epochs):
        print(e)
        train_model.train()
        train_loss = 0.

        with tqdm(dl_train, unit="batch") as tepoch:
            for images, labels in tepoch:
                images = images.to(device)
                targets = labels.to(device)

                train_optimizer.zero_grad()
                loss = train_loss_fn(train_model(images), targets.type(torch.float))
                loss.backward()            # back-propagation
                train_optimizer.step()     # update all parameters

                # .item() turns the scalar loss tensor into a python float
                batch_loss = loss.item()
                train_loss += batch_loss
                tepoch.set_postfix(loss=batch_loss)

        print('train_loss %s ' %(train_loss / len(dl_train)))

        curr_auc_dic = eval_fn(train_model, dl_valid, verbose=False)
        improved = (
            best_auc_dic is None
            or best_auc_dic['Average']['AUC'] < curr_auc_dic['Average']['AUC']
        )
        if improved:
            best_auc_dic = curr_auc_dic
            torch.save(train_model.state_dict(), checkpoint_path)
            print('curr best %s' %best_auc_dic['Average']['AUC'])

    return checkpoint_path



In [12]:
def make_fine_labels(features, org_labels, fine_clfs,  n_fine_classes):
    """Expand original multi-hot labels into multi-hot fine (cluster) labels.

    For every original class ``idx``, the samples positive for that class are
    assigned to the cluster predicted by ``fine_clfs[idx]``. Cluster ids are
    offset by the total ``n_components_`` of the preceding classifiers so that
    each class owns a contiguous range of fine label columns.

    Parameters
    ----------
    features : numpy.ndarray, shape (n_samples, n_features)
    org_labels : numpy.ndarray, shape (n_samples, n_classes)
        A sample belongs to class ``idx`` when ``org_labels[row, idx] == 1``.
    fine_clfs : sequence
        One fitted model per class, each exposing ``predict`` and
        ``n_components_``.
    n_fine_classes : int
        Number of columns of the returned array.

    Returns
    -------
    numpy.ndarray, shape (n_samples, n_fine_classes)
    """
    n_samples = len(features)
    fine_labels = np.zeros((n_samples, n_fine_classes))
    curr = 0
    for idx, clf in enumerate(fine_clfs):
        rows = np.flatnonzero(org_labels[:, idx] == 1.)
        # One batched predict per class instead of one call per sample; the
        # classifier is skipped when the class has no positive sample.
        if rows.size:
            clusters = np.asarray(clf.predict(features[rows])).ravel() + curr
            fine_labels[rows, clusters] = 1.
        curr += clf.n_components_
    return fine_labels

def make_coarse_labels(fine_labels, coarse_to_fine):
    """Translate multi-hot fine labels into multi-hot coarse labels.

    ``coarse_to_fine`` maps each coarse label id to the fine label ids grouped
    under it. A coarse label is switched on for a sample as soon as one of its
    fine labels equals 1 for that sample.

    Returns an array of shape ``(len(fine_labels), len(coarse_to_fine))``.
    """
    labels = np.zeros((len(fine_labels), len(coarse_to_fine)))

    # fine id -> all coarse ids that contain it
    fine_to_coarse = {}
    for coarse_id, members in coarse_to_fine.items():
        for fine_id in members:
            fine_to_coarse.setdefault(fine_id, []).append(coarse_id)

    # visit only the (row, col) positions whose fine label is exactly 1
    active = np.nonzero(np.asarray(fine_labels) == 1.)
    for row, col in zip(*active):
        for coarse_id in fine_to_coarse.get(col, ()):
            labels[row, coarse_id] = 1.0

    return labels

def get_fine_data_target(features, fine_targets,  wanted_labels):
    """Select the samples that carry at least one of the wanted fine labels.

    Parameters
    ----------
    features : numpy.ndarray, shape (n_samples, n_features)
    fine_targets : numpy.ndarray, shape (n_samples, n_fine_classes)
    wanted_labels : sequence of int
        Fine label columns to keep.

    Returns
    -------
    kept_features : numpy.ndarray, shape (n_kept, n_features)
    kept_targets : numpy.ndarray, shape (n_kept, len(wanted_labels))
        Rows whose selected targets sum to more than 0, in their original
        order. When no row qualifies, both arrays are empty (0 rows) instead
        of an error being raised while stacking.
    """
    selected_fine_targets = fine_targets[:, wanted_labels]
    # keep the rows where there is content in the selected columns
    has_content = selected_fine_targets.sum(axis=1) > 0
    return features[has_content], selected_fine_targets[has_content]

def predict_hierachy(features, coarse_clf, clf_fines, fine_to_org, coarse_to_fine, n_actual_class):
    """Combine coarse and fine classifier outputs into original-class scores.

    The score of an original class is the average, over every (fine, coarse)
    pair mapped to it, of ``coarse score * fine score``.

    Parameters
    ----------
    features : numpy.ndarray, shape (n_samples, n_features)
    coarse_clf
        ``predict(features)`` must return an array indexed as
        ``[sample, coarse_label]``.
    clf_fines : dict
        Coarse label id -> classifier over the fine labels of that coarse
        label; ``predict(features)`` must return one column per entry of
        ``coarse_to_fine[k]``, in the same order. Not called for coarse
        labels that contain a single fine label (that label scores 1).
    fine_to_org : dict
        Fine label id -> original class index.
    coarse_to_fine : dict
        Coarse label id -> list of fine label ids.
    n_actual_class : int
        Number of original classes (columns of the result).

    Returns
    -------
    numpy.ndarray, shape (n_samples, n_actual_class)
    """
    n_samples = len(features)
    n_fine_classes = len(fine_to_org)
    predict_coarse = coarse_clf.predict(features)
    predics_fines = np.zeros((n_samples, n_fine_classes))

    for k, clf in clf_fines.items():
        cols = coarse_to_fine[k]
        if len(cols) == 1:
            # a coarse label with one fine label needs no fine classifier
            predics_fines[:, cols[0]] += 1
        else:
            pp = clf.predict(features)
            for idx, c in enumerate(cols):
                predics_fines[:, c] += pp[:, idx]

    # Size the result from the parameter; the previous code read a global
    # named `n_actual_classes` here and only worked when the calling cell
    # happened to define it.
    predicts = np.zeros((n_samples, n_actual_class))

    # fine id -> all coarse ids that contain it
    fine_to_coarse = defaultdict(list)
    for k, v in coarse_to_fine.items():
        for vv in v:
            fine_to_coarse[vv].append(k)

    counts = defaultdict(int)
    for fine, org in fine_to_org.items():
        p_fine = predics_fines[:, fine]
        for c in fine_to_coarse[fine]:
            p_coarse = predict_coarse[:, c]
            predicts[:, org] += p_coarse*p_fine
            counts[org] += 1

    # average over the number of (fine, coarse) pairs contributing to a class
    for i in range(n_actual_class):
        if counts[i] != 0:
            predicts[:, i] /= counts[i]
    return predicts

#minimal ("ghetto") bagging decision forest; no posterior configuration is done
class SimpleDF():
    """Bagged ensemble of ``DecisionTreeClassifier`` with averaged predictions.

    Parameters
    ----------
    n_estimators : int
        Number of trees.
    max_depth : int
        Passed to every ``DecisionTreeClassifier``.
    sample_size : float
        When below 1, each tree is fitted on a ``train_test_split`` subsample
        of that ``train_size``; otherwise every tree sees all the data.
    """

    def __init__(self, n_estimators, max_depth, sample_size = 0.75):
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.sample_size = sample_size
        self.trees = None
        self.dim = 0

    def fit(self, X, y, verbose = False):
        """Fit ``n_estimators`` trees on ``X`` and the 2-D target array ``y``."""
        self.trees = []
        self.dim = y.shape[1]
        for i in range(self.n_estimators):
            if self.sample_size < 1:
                cX, _, cY, _ = train_test_split(X, y, train_size=self.sample_size)
            else:
                cX = X
                cY = y
            df = DecisionTreeClassifier(max_depth=self.max_depth)
            df.fit(cX,cY)
            self.trees.append(df)
            if verbose:
                # parenthesised: '%' binds tighter than '+', so the previous
                # `'...%s' %i+1` tried to add an int to a str
                print('build tree %s' %(i+1))

    def predict(self, X, verbose=False):
        """Return the per-output mean of the trees' predictions.

        The result has shape ``(X.shape[0], n_outputs)`` where ``n_outputs``
        is the number of target columns seen in ``fit``.
        """
        n_samples = X.shape[0]
        res = np.zeros((n_samples, self.dim))
        for i in range(self.n_estimators):
            p = self.trees[i].predict(X)
            # a tree fitted on a single target column returns a 1-D array;
            # bring every prediction to (n_samples, dim) before accumulating
            res += np.asarray(p).reshape(n_samples, self.dim)
            if verbose:
                print('predicted %s' %(i+1))
        return res /self.n_estimators
    
        
        


In [13]:
# Re-seed torch and numpy so this cell is reproducible on its own.
acorn = 1234
torch.manual_seed(acorn)
np.random.seed(acorn)

# Experiment settings
batch_size = 16
train_size = 0.01
seeds = np.random.randint(10000, size=1000)

# Pre-generate 100 train/validate splits of the full training frame, one per
# seed, so that iteration i of the experiment always sees the same split.
trains, validates = [], []
for i in range(100):
    train, validate = train_test_split(
        train_full,
        test_size=1-train_size,
        shuffle=True,
        random_state=seeds[i],
    )
    trains.append(train)
    validates.append(validate)

# Held-out CheXpert test set
test_dataset = ChestXRayDataset(test_df, test_transform, data_root)
dl_test = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True
)

# Held-out CheXphoto test set
test_photo_dataset = ChestXRayDataset(photo_test_df, test_transform, data_root_photo)
dl_test_photo = torch.utils.data.DataLoader(
    test_photo_dataset, batch_size=batch_size, shuffle=True
)

In [None]:
# Experiment driver: for each pre-generated split, train the base CNN, score it
# on both test sets, then build the coarse/fine hierarchy on top of the CNN's
# features and score that as well.

# Full AUC dictionaries per iteration (base model vs. hierarchy, per test set).
base_test_auc = []
hierachy_test_auc = []

base_test_auc_photo = []
hierachy_test_auc_photo = []

# Average-AUC scalars per iteration, used for the running means printed below.
debug_auc =[]
debug_auc_photo = []

debug_hierachy_auc = []
debug_hierachy_auc_photo = []

# Rebinds the n_iter set in the setup cell; must not exceed the number of
# pre-generated splits in `trains` / `validates`.
n_iter = 20
train_epoch= 10

for iteration in range(n_iter):
    print(iteration)
    torch.cuda.empty_cache()
    # Only the first 1000 rows of the validation split are used.
    train_df, validate_df = trains[iteration], validates[iteration][:1000]
    train_dataset = ChestXRayDataset(train_df, train_transform, data_root)
    valid_dataset = ChestXRayDataset(validate_df, test_transform, data_root)

    dl_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    dl_valid = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
    
    # Initialize the model; its name (and hence its checkpoint file) encodes
    # the train fraction and the iteration number.
    c_model = Resnet50Base(len(train_dataset.class_names), 'base_model_%s_%s' %(train_size, iteration))
    c_model.to(device)
    c_loss_fn = torch.nn.BCELoss().to(device)
    c_optimizer = torch.optim.Adam(c_model.parameters(), lr=5e-4)
    
    # train_model returns the checkpoint file holding the best-validation weights.
    best_model_name = train_model(train_epoch, c_model, c_loss_fn, c_optimizer, dl_train, dl_valid, eval_model)
    print('done eval base model ')
    
    # Reload the best checkpoint into a fresh model; it is used both for the
    # base-model test scores and as the embedding network.
    print('using model %s' %best_model_name)
    c_model = Resnet50Base(len(train_dataset.class_names), '')
    c_model.load_state_dict(torch.load(best_model_name))
    c_model.to(device)

    # Evaluate the base model on both test sets.
    test_auc = eval_model(c_model, dl_test, verbose=False)
    test_photo_auc = eval_model(c_model, dl_test_photo, verbose=False)
    base_test_auc.append(test_auc)
    base_test_auc_photo.append(test_photo_auc)
    print('auc at best base model %s %s' %(test_auc['Average']['AUC'], test_photo_auc['Average']['AUC'] ))
    debug_auc.append(test_auc['Average']['AUC'])
    debug_auc_photo.append(test_photo_auc['Average']['AUC'] )
    
    
    
    # Feature extractor built from the reloaded model.
    embd_model = get_feature_extractor(c_model)
    embd_model.to(device)
    # NOTE: dl_train wraps train_transform (random horizontal flip) and
    # shuffles, so these features come from augmented images in shuffled
    # order; features and targets stay aligned with each other.
    train_features, train_targets = extract_fetures_targets(embd_model, dl_train)
    
    # The coarse_clf returned here is rebound to a SimpleDF a few lines below.
    fine_clfs, fine_to_org, coarse_clf, coarse_to_fine = gen_coarse_fine_clfs(train_features, train_targets, n_cluster_min=1, n_cluster_max=12 )
    n_fine_classes = len(fine_to_org)
    n_coarse_classes = len(coarse_to_fine)
    print('done generating clfs: n_fine_clf: %s n_coarse: %s' %(n_fine_classes, n_coarse_classes))
    
    # validate_features / validate_target are computed but not used further
    # in this cell.
    validate_features, validate_target = extract_fetures_targets(embd_model, dl_valid)
    test_features, test_target = extract_fetures_targets(embd_model, dl_test)
    test_photo_features, test_photo_targets = extract_fetures_targets(embd_model, dl_test_photo)
    
    # Relabel the training set with fine and coarse labels, then fit one
    # forest for the coarse labels and one forest per coarse label for the
    # fine labels grouped under it.
    train_fine_targets = make_fine_labels(train_features, train_targets, fine_clfs, len(fine_to_org))
    train_coarse_targets = make_coarse_labels(train_fine_targets, coarse_to_fine)
    coarse_clf = SimpleDF(n_estimators=100, sample_size=0.75, max_depth= 20)
    coarse_clf.fit(train_features, train_coarse_targets)
    clf_fines = {}
    for k, v in coarse_to_fine.items():
        f, t = get_fine_data_target(train_features, train_fine_targets, v)
        clf = SimpleDF(n_estimators=100, sample_size=1.0, max_depth=10)
        clf.fit(f,t)
        clf_fines[k] = clf
    # Number of original classes (columns of the training targets).
    n_actual_classes = train_targets.shape[1]
    # Score the hierarchy on both test sets; test_auc / test_photo_auc are
    # rebound here from the base-model results to the hierarchy results.
    predicts_test = predict_hierachy(test_features, coarse_clf, clf_fines, fine_to_org, coarse_to_fine, n_actual_classes)
    test_auc = eval_auc(test_target, predicts_test, dl_train.dataset.class_names, alpha=0.95)
    predicts_photo = predict_hierachy(test_photo_features, coarse_clf, clf_fines, fine_to_org, coarse_to_fine, n_actual_classes)
    test_photo_auc = eval_auc(test_photo_targets, predicts_photo, dl_train.dataset.class_names, alpha=0.95)

    print('auc at best hierachy model %s %s' %(test_auc['Average']['AUC'], test_photo_auc['Average']['AUC'] ))
    debug_hierachy_auc.append(test_auc['Average']['AUC'])
    debug_hierachy_auc_photo.append(test_photo_auc['Average']['AUC'])
    
    hierachy_test_auc.append(test_auc)
    hierachy_test_auc_photo.append(test_photo_auc)
    
    # Running means over the iterations so far: base model first, hierarchy second.
    print(np.average(debug_auc), np.average(debug_auc_photo))
    print(np.average(debug_hierachy_auc), np.average(debug_hierachy_auc_photo))
    print('!!!!!!!!!!!')

    

0


  2%|▏         | 1/58 [00:00<00:05,  9.81batch/s, loss=0.653]

0


100%|██████████| 58/58 [00:05<00:00, 11.00batch/s, loss=0.534]


train_loss 0.507447447242408 


  3%|▎         | 2/58 [00:00<00:04, 11.29batch/s, loss=0.387]

curr best 0.7580720362886509
1


100%|██████████| 58/58 [00:05<00:00, 11.06batch/s, loss=0.435]


train_loss 0.44875717933835657 


  3%|▎         | 2/58 [00:00<00:05, 10.96batch/s, loss=0.409]

curr best 0.7743931129394735
2


100%|██████████| 58/58 [00:05<00:00, 10.36batch/s, loss=0.452]


train_loss 0.423359257907703 


  3%|▎         | 2/58 [00:00<00:04, 11.71batch/s, loss=0.395]

3


100%|██████████| 58/58 [00:05<00:00, 10.88batch/s, loss=0.428]


train_loss 0.4118287347514054 


  3%|▎         | 2/58 [00:00<00:05, 11.03batch/s, loss=0.253]

curr best 0.7852306558238806
4


100%|██████████| 58/58 [00:05<00:00, 11.05batch/s, loss=0.32] 


train_loss 0.3841230941229853 


  0%|          | 0/58 [00:00<?, ?batch/s]

curr best 0.7944779777998248
5


100%|██████████| 58/58 [00:06<00:00,  8.85batch/s, loss=0.451]


train_loss 0.35425084829330444 


  0%|          | 0/58 [00:00<?, ?batch/s, loss=0.315]

curr best 0.7978933893687644
6


100%|██████████| 58/58 [00:05<00:00, 10.88batch/s, loss=0.278]


train_loss 0.31253676758757953 


  3%|▎         | 2/58 [00:00<00:04, 11.85batch/s, loss=0.25]

7


100%|██████████| 58/58 [00:05<00:00, 10.47batch/s, loss=0.266]


train_loss 0.2789884949552602 


  2%|▏         | 1/58 [00:00<00:08,  7.09batch/s, loss=0.243]

8


 62%|██████▏   | 36/58 [00:05<00:02,  9.57batch/s, loss=0.278]

In [None]:
0.624059829059829 0.6343837712356231