In [None]:
import os 

In [None]:
# Move into the project folder (relative to the notebook's start directory).
# The check makes the cell idempotent: re-running it when the kernel is already
# inside the folder would otherwise fail, because the relative path no longer
# resolves from there.
PROJECT_DIR = "drive/MyDrive/bmi/"
if not os.getcwd().replace(os.sep, "/").endswith(PROJECT_DIR.rstrip("/")):
    os.chdir(PROJECT_DIR)

In [None]:
! pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 15.6 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 55.7 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 73.0 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.25.1


In [None]:
import os 
import argparse 
import random 
import pandas as pd
import numpy as np 
import collections 
from tqdm import tqdm 
from datetime import datetime

# PyTorch libraries 
import torch 
import torch.nn as nn 
from torch.utils.data import DataLoader 
from torchvision.datasets import DatasetFolder
from torchvision import models, transforms 

# Hugging Face datasets 
#import datasets 

# Transformers libraries 
from transformers import TrainingArguments, Trainer
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import get_linear_schedule_with_warmup 

# import sklearn models 
# from sklearn.neighbors import KNeighborsClassifier

# simple models
from models import LogisticRegression, BasicCNNModel, DenseCNNModel, BasicCNNCountryModel, ViTCountryModel, ViTMosaiksModel
from SatelliteImageDataset import SatelliteImageDataset, SatelliteImageMetadataDataset, SatelliteImageMosaiksDataset

from sklearn.metrics import confusion_matrix

In [None]:
# One seed for every random number generator used in this notebook
# (Python stdlib, NumPy and PyTorch) so that runs are repeatable.
RANDOM_SEED = 231
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

class ImageClassificationCollator:
    """Turn a list of dataset samples into one batch for the models in this notebook.

    Each sample is indexed positionally: ``sample[0]`` is the image, ``sample[1]`` the
    label and, when ``metadata`` or ``mosaiks`` is set, ``sample[2]`` the extra input.

    Args:
        feature_extractor: callable that prepares images. With ``transforms`` falsy it is
            called once with the whole list of images and ``return_tensors='pt'`` and its
            result is used as the batch mapping; with ``transforms`` truthy it is called
            once per image (converted to a NumPy array) and the results are stacked.
        transforms: selects the per-image mode described above.
        metadata: container of metadata names; only ``"country"`` is read.
        mosaiks: when truthy (and ``metadata`` is falsy) adds ``'mosaiks_features'``.
    """

    def __init__(self, feature_extractor, transforms = False, metadata = False, mosaiks = False):
        self.feature_extractor = feature_extractor
        self.transforms = transforms
        self.metadata = metadata
        self.mosaiks = mosaiks

    def __call__(self, batch):
        """Return the batch mapping with ``pixel_values``, ``labels`` and optional extras."""
        if not self.transforms:
            # Batch mode: the extractor receives every image in one call.
            images = [sample[0] for sample in batch]
            encodings = self.feature_extractor(images, return_tensors='pt')
        else:
            # Per-image mode: the callable works on NumPy arrays, one image at a time.
            per_image = []
            for sample in batch:
                per_image.append(self.feature_extractor(sample[0].cpu().detach().numpy()))
            encodings = {"pixel_values": torch.stack(per_image)}

        label_ids = [sample[1] for sample in batch]
        encodings['labels'] = torch.tensor(label_ids, dtype=torch.long)

        # Metadata takes precedence; MOSAIKS features are only added when no metadata is used.
        if self.metadata:
            if "country" in self.metadata:
                encodings['country'] = torch.tensor([sample[2] for sample in batch])
        elif self.mosaiks:
            extra_features = np.array([sample[2] for sample in batch])
            encodings['mosaiks_features'] = torch.tensor(extra_features, dtype = torch.float32)

        return encodings

# create model and collator
_VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
_METADATA_MODELS = ('basic_cnn', 'ViT')
_TORCHVISION_MODELS = ('resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet')
_SIMPLE_MODELS = ('basic_cnn', 'basic_cnn_novit', 'dense_cnn', 'logistic_regression')


def _vit_collators(**collator_kwargs):
    """Build one collator on the pretrained ViT feature extractor and use it for train and val."""
    feature_extractor = ViTFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
    collator = ImageClassificationCollator(feature_extractor, **collator_kwargs)
    return (collator, collator)


def _imagenet_collators():
    """Build (train, val) collators that use torchvision transforms producing 224x224 inputs."""
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Training: random crop + horizontal flip for augmentation.
    train_data_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Validation: deterministic resize + centre crop.
    val_data_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_collator = ImageClassificationCollator(train_data_transforms, transforms=True)
    val_collator = ImageClassificationCollator(val_data_transforms, transforms=True)
    return (train_collator, val_collator)


def create_model_and_collator(args, model_name, metadata = None, cnt_id_map = None):
    """Create the model and the (train, val) collators for the requested configuration.

    Branch priority is: ``metadata`` first, then ``args['mosaiks']``, then ``model_name``.
    The number of output classes comes from the module-level ``CLASSES``.

    Args:
        args: mapping of run options; only ``args['mosaiks']`` is read here.
        model_name: one of ``'ViT'``, the torchvision names (``'resnet'``, ``'alexnet'``,
            ``'vgg'``, ``'squeezenet'``, ``'densenet'``) or the simple models
            (``'basic_cnn'``, ``'basic_cnn_novit'``, ``'dense_cnn'``, ``'logistic_regression'``).
        metadata: when truthy, a country-aware model is built; only ``'basic_cnn'`` and
            ``'ViT'`` are supported in that mode.
        cnt_id_map: country-to-id mapping passed to the country-aware models; its length
            is used as the number of country embeddings.

    Returns:
        ``(collators, model)`` where ``collators`` is a ``(train, val)`` tuple.

    Raises:
        NotImplementedError: for an unknown ``model_name``, or a ``model_name`` that has no
            metadata variant while ``metadata`` is set (previously the latter failed with
            an UnboundLocalError at the return statement).
    """
    if metadata:
        # Validate before loading anything: there is no country model for other names.
        if model_name not in _METADATA_MODELS:
            raise NotImplementedError(
                f"No metadata-aware model for model_name={model_name!r}; "
                f"supported: {list(_METADATA_MODELS)}"
            )
        collators = _vit_collators(metadata=metadata)
        country_model_cls = BasicCNNCountryModel if model_name == 'basic_cnn' else ViTCountryModel
        model = country_model_cls(
            n_classes=CLASSES, cnt_id_map = cnt_id_map, num_country_embeddings=len(cnt_id_map)
        )

    elif args['mosaiks']:
        collators = _vit_collators(mosaiks=args['mosaiks'])
        # NOTE(review): 3999 is presumably the number of MOSAIKS feature columns — confirm
        # against the features CSV.
        model = ViTMosaiksModel(n_classes=CLASSES, mosaiks_dim = 3999)

    elif model_name == "ViT":
        collators = _vit_collators()
        model = ViTForImageClassification.from_pretrained(_VIT_CHECKPOINT, num_labels=CLASSES)

    elif model_name in _TORCHVISION_MODELS:
        # All of these models expect images of shape (3, 224, 224).
        collators = _imagenet_collators()

        # Each pretrained network gets its final layer replaced by a CLASSES-way head.
        if model_name == 'resnet':
            model = models.resnet18(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, CLASSES)
        elif model_name == 'alexnet':
            model = models.alexnet(pretrained=True)
            model.classifier[6] = nn.Linear(model.classifier[6].in_features, CLASSES)
        elif model_name == 'vgg':
            model = models.vgg11_bn(pretrained=True)
            model.classifier[6] = nn.Linear(model.classifier[6].in_features, CLASSES)
        elif model_name == 'squeezenet':
            model = models.squeezenet1_0(pretrained=True)
            model.classifier[1] = nn.Conv2d(512, CLASSES, kernel_size=1, stride=1)
            model.num_classes = CLASSES
        else:
            # densenet
            model = models.densenet121(pretrained=True)
            model.classifier = nn.Linear(model.classifier.in_features, CLASSES)

    elif model_name in _SIMPLE_MODELS:
        # The "novit" variant uses torchvision transforms; the others reuse the ViT
        # feature extractor for preprocessing.
        if "novit" in model_name:
            collators = _imagenet_collators()
        else:
            collators = _vit_collators()

        if model_name == "logistic_regression":
            model = LogisticRegression(n_classes=CLASSES)
        elif "basic_cnn" in model_name:
            model = BasicCNNModel(n_classes=CLASSES)
        else:
            # dense_cnn
            model = DenseCNNModel(n_classes=CLASSES)

    else:
        raise NotImplementedError

    print(f'Model name: {model_name}')

    return collators, model


def create_dataset(args, collator_fns, metadata = None, cnt_id_map = None, val_split = 0.15):
    """Build the train and validation data loaders.

    The dataset class is chosen by ``metadata`` (first) or ``args['mosaiks']``. The
    train/val split uses a permutation cached in ``indices_perm2.npy`` in the current
    working directory, so that repeated runs share the same split; the file is created
    on first use.

    Args:
        args: mapping with ``data_dir``, ``csv_file``, ``outcome``, ``mosaiks``,
            ``batch_size`` and, in mosaiks mode, ``mosaiks_csv_file``.
        collator_fns: ``(train_collate_fn, val_collate_fn)``.
        metadata: when truthy, selects ``SatelliteImageMetadataDataset``.
        cnt_id_map: unused here; kept for call-site compatibility.
        val_split: fraction of samples held out for validation (the last part of the
            permutation).

    Returns:
        ``[train_dl, val_dl]``.

    Raises:
        ValueError: if the cached permutation does not have one entry per dataset sample
            (for example, it was created for a different dataset).
    """
    split_file = "indices_perm2.npy"

    def npy_loader(path):
        return torch.from_numpy(np.load(path))

    # Load the dataset from the directory; all variants share these arguments.
    common_kwargs = dict(
        root = args['data_dir'],
        csv_path = args['csv_file'],
        outcome = args['outcome'],
        loader = npy_loader,
    )
    if metadata:
        dataset = SatelliteImageMetadataDataset(**common_kwargs)
    elif args['mosaiks']:
        dataset = SatelliteImageMosaiksDataset(
            mosaiks_csv_path = args['mosaiks_csv_file'], **common_kwargs
        )
    else:
        dataset = SatelliteImageDataset(**common_kwargs)

    # Reuse the cached permutation when present so every run sees the same split.
    if os.path.isfile(split_file):
        indices = np.load(split_file).tolist()
        if len(indices) != len(dataset):
            raise ValueError(
                f"{split_file} holds {len(indices)} indices but the dataset has "
                f"{len(dataset)} samples; delete the file to regenerate the split."
            )
    else:
        indices = torch.randperm(len(dataset)).tolist()
        np.save(split_file, indices)

    # Split by an explicit train length: slicing with `-n_val` is wrong when n_val == 0
    # (`indices[:-0]` is empty and `indices[-0:]` is the whole list).
    n_val = int(np.floor(len(indices) * val_split))
    n_train = len(indices) - n_val
    train_ds = torch.utils.data.Subset(dataset, indices[:n_train])
    val_ds = torch.utils.data.Subset(dataset, indices[n_train:])

    train_dl = DataLoader(train_ds, batch_size=args['batch_size'], collate_fn=collator_fns[0], shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=args['batch_size'], collate_fn=collator_fns[1], shuffle=False)

    return [train_dl, val_dl]

def dataset_statistics(args, dataset_loader):
    """Count how often each label occurs across all batches of a loader.

    Args:
        args: unused; kept for signature compatibility with the other helpers.
        dataset_loader: iterable of batches, each a mapping whose ``'labels'`` entry is a
            tensor of class ids (any shape; it is flattened).

    Returns:
        ``collections.Counter`` mapping each label (as a plain Python int) to its count.
    """
    label_stats = collections.Counter()
    for batch in dataset_loader:
        # `.tolist()` yields Python ints, so the keys print as `1`, not as NumPy scalars.
        label_stats.update(batch['labels'].cpu().flatten().tolist())
    return label_stats


def measure_accuracy(outputs, labels):
    """Score one batch of predictions.

    Args:
        outputs: array of per-class scores; the predicted class is the argmax along axis 1.
        labels: array of true class ids; flattened before the comparison.

    Returns:
        ``(n_correct, n_samples, confusion)`` where ``confusion`` is the confusion matrix
        over the module-level ``CLASS_NAMES`` (true labels first, predictions second).
    """
    predicted = np.argmax(outputs, axis = 1).flatten()
    truth = labels.flatten()
    n_correct = np.sum(predicted == truth)
    return n_correct, len(truth), confusion_matrix(truth, predicted, labels=CLASS_NAMES)

def validation(args, val_loader, model, criterion, metadata, device, name = 'Validation', write_file=None):
    """Evaluate ``model`` on ``val_loader`` and report accuracy and the confusion matrix.

    The model is switched to eval mode and left there; callers that keep training must
    call ``model.train()`` again afterwards.

    Args:
        args: mapping; ``args['mosaiks']`` and ``args['model_name']`` select how the model
            is called.
        val_loader: iterable of batches with ``'pixel_values'`` and ``'labels'`` (plus
            ``'country'`` or ``'mosaiks_features'`` in the respective modes).
        model: the network to evaluate.
        criterion: loss function called as ``criterion(outputs, labels)``.
        metadata: when truthy, the model is called with the batch's ``'country'`` tensor.
        device: device the batch tensors are moved to.
        name: label used in the printed/written report.
        write_file: optional open text file that receives the same report.

    Returns:
        ``(total_loss, accuracy_percent)`` — the loss summed over batches (not averaged)
        and the overall accuracy in percent. An empty loader yields an accuracy of 0.
    """
    model.eval()
    total_loss = 0.
    total_correct = 0
    total_sample = 0
    total_confusion = np.zeros((CLASSES, CLASSES))

    # These models return the logits tensor directly; any other model is expected to
    # return a mapping with a 'logits' entry.
    plain_logit_models = [
        'basic_cnn', 'basic_cnn_novit', 'dense_cnn', 'logistic_regression',
        'resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet'
    ]

    for batch in tqdm(val_loader):
        inputs = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        with torch.no_grad():
            if metadata:
                country = batch['country'].to(device)
                outputs = model(inputs, country)
            elif args['mosaiks']:
                mosaiks_features = batch['mosaiks_features'].to(device)
                outputs = model(inputs, mosaiks_features)
            elif args['model_name'] in plain_logit_models:
                outputs = model(inputs)
            else:
                outputs = model(inputs)['logits']

            loss = criterion(outputs, labels)

        total_loss += loss.cpu().item()

        correct_n, sample_n, c_matrix = measure_accuracy(outputs.cpu().numpy(), labels.cpu().numpy())
        total_correct += correct_n
        total_sample += sample_n
        total_confusion += c_matrix

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = total_correct / total_sample if total_sample else 0.0
    # Mean of the per-class recalls (diagonal over row sums of the confusion matrix).
    weighted_accuracy = np.mean( np.diag(total_confusion) / np.sum(total_confusion, 1) )

    report_lines = [
        f'*** Accuracy on the {name} set: {accuracy}',
        f'*** Weighted accuracy on the {name} set: {weighted_accuracy}',
        f'*** Confusion matrix:\n{total_confusion}',
    ]
    for line in report_lines:
        print(line)
    if write_file:
        # Every entry gets its own line; previously the weighted-accuracy entry had no
        # trailing newline and ran into the confusion-matrix header.
        for line in report_lines:
            write_file.write(line + '\n')

    return total_loss, float(accuracy) * 100



def train(args, data_loaders, epoch_n, model, optim, scheduler, criterion, metadata, device, write_file=None):
    """Train ``model`` for ``epoch_n`` epochs with periodic validation and checkpointing.

    Validation runs on every batch whose index within the epoch is a multiple of
    ``args['val_every']`` — this includes index 0, so when an epoch has fewer batches than
    ``val_every`` the only validation happens right after the first batch of each epoch.
    Whenever validation accuracy improves and a save path is configured, the model is saved.

    Args:
        args: mapping with ``mosaiks``, ``model_name``, ``val_every`` and optionally
            ``save_path`` (missing or falsy disables checkpointing).
        data_loaders: ``[train_loader, val_loader]``.
        epoch_n: number of passes over the training loader.
        model: network to optimise.
        optim: optimizer; stepped once per batch.
        scheduler: optional learning-rate scheduler; stepped once per batch when truthy.
        criterion: loss function called as ``criterion(outputs, labels)``.
        metadata: when truthy, the model is called with the batch's ``'country'`` tensor.
        device: device the batch tensors are moved to.
        write_file: optional open text file that receives progress messages.

    Returns:
        The best validation accuracy (in percent) seen during training; 0 if validation
        never ran.
    """
    print("\n>>> Training starts...")

    if write_file:
        write_file.write("\n>>> Training starts...")

    model.train()

    # These models return the logits tensor directly; any other model is expected to
    # return a mapping with a 'logits' entry.
    plain_logit_models = [
        'basic_cnn', 'basic_cnn_novit', 'dense_cnn', 'logistic_regression',
        'resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet'
    ]

    # A missing key simply disables checkpointing instead of raising KeyError mid-training.
    save_path = args.get('save_path')

    best_val_acc = 0
    for epoch in range(epoch_n):
        print("*** Epoch:", epoch)
        total_train_loss = 0.
        total_correct = 0
        total_sample = 0

        for i, batch in enumerate(tqdm(data_loaders[0])):
            optim.zero_grad()
            inputs = batch['pixel_values'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            # forward pass
            if metadata:
                country = batch['country'].to(device)
                outputs = model(inputs, country)
            elif args['mosaiks']:
                mosaiks_features = batch['mosaiks_features'].to(device)
                outputs = model(inputs, mosaiks_features)
            elif args['model_name'] in plain_logit_models:
                outputs = model(inputs)
            else:
                outputs = model(inputs)['logits']

            loss = criterion(outputs, labels)
            correct_n, sample_n, _ = measure_accuracy(
                outputs.cpu().detach().numpy(), labels.cpu().detach().numpy()
            )
            total_correct += correct_n
            total_sample += sample_n

            total_train_loss += loss.cpu().item()

            # backward pass
            loss.backward()
            optim.step()

            if scheduler: scheduler.step()

            if i % args['val_every'] == 0:
                print(f'*** Average Loss: {total_train_loss / (i+1)}')
                print(f'*** Running accuracy on the train set: {total_correct/total_sample}')
                if write_file:
                    write_file.write(f'\nEpoch: {epoch}, Step: {i}\n')
                    # Log the scalar value, not the tensor repr (device / grad_fn noise).
                    write_file.write(f'*** Loss: {loss.item()}\n')
                    write_file.write(f'*** Running accuracy on the train set: {total_correct/total_sample}\n')

                _, val_acc = validation(args, data_loaders[1], model, criterion, metadata, device, write_file=write_file)

                # validation() leaves the model in eval mode.
                model.train()

                if best_val_acc < val_acc:
                    best_val_acc = val_acc

                    if save_path:
                        if args['mosaiks']:
                            # Imported here so the function does not depend on a later cell.
                            import pickle
                            # NOTE: pickles the whole module object; loading it requires the
                            # same class definition and should only be done with trusted files.
                            with open(save_path + ".pkl", "wb") as f:
                                pickle.dump(model, f)
                        elif args['model_name'] in ['ViT']:
                            model.save_pretrained(save_path)
                        else:
                            torch.save(model.state_dict(), save_path)

    return best_val_acc



In [None]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import ViTForImageClassification, ViTModel, ViTConfig

class ViTMosaiksModel(nn.Module):
    """ViT encoder whose pooled output is concatenated with MOSAIKS features and classified by an MLP.

    NOTE(review): a class of the same name is also imported from ``models`` earlier in the
    notebook; once this cell runs, this definition is the one in scope — confirm that is
    intended.

    Args:
        n_classes: number of output classes (size of the final linear layer).
        mosaiks_dim: number of MOSAIKS features per sample.
        mlp_dim: width of the two hidden layers of the classification head.
    """

    def __init__(self, n_classes, mosaiks_dim = 64, mlp_dim = 128):
        super().__init__()

        self.n_classes = n_classes

        # Load the pretrained encoder directly. `from_pretrained` is a classmethod, so
        # constructing a randomly initialised ViTModel first only wasted time and memory.
        self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

        # Read the encoder width from its config instead of hard-coding it.
        hidden_dim = self.model.config.hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim + mosaiks_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, self.n_classes)
        )

    def forward(self, X, mosaiks_features):
        """Return class logits for images ``X`` and their ``mosaiks_features``.

        ``mosaiks_features`` is concatenated to the encoder's pooled output along dim 1,
        so it must have one row per image and ``mosaiks_dim`` columns.
        """
        model_out = self.model(X)['pooler_output']
        # Follow the encoder output's device rather than guessing from CUDA availability,
        # which breaks when the model runs on CPU on a machine that has a GPU.
        mosaiks_features = mosaiks_features.to(model_out.device)
        concat_output = torch.cat((model_out, mosaiks_features), dim=1)
        logits = self.mlp(concat_output)

        return logits

In [None]:
# Run configuration read by create_model_and_collator / create_dataset / train.
args = {
  'data_dir': 'west_africa',
  'csv_file': 'west_africa_df',
  'mosaiks_csv_file': 'west_africa_mosaiks_feats',
  'mosaiks': False,
  'outcome': 'Mean_BMI_bin',
  'model_name':'ViT',

  'val_every': 200,
  'batch_size': 64,
  'lr':2e-5,
  'eps':1e-8,

  # train() reads args['save_path'] whenever validation improves; None disables
  # checkpointing until a path is set.
  'save_path': None
}

# set device to GPU if possible
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Number of classes; labels are the integers 0 .. CLASSES-1
CLASSES = 3
CLASS_NAMES = list(range(CLASSES))

epoch_n = 10

# read df
df = pd.read_csv(args['csv_file'] + ".csv")

# if args.metadata and args.mosaiks:
#     raise NotImplementedError("Functionality for both mosaiks and metadata has not been implemented yet.")

cnt_id_map = None
metadata = None
# if args.metadata:
#     metadata = ["country"]
#     unique_countries = list(set(df["country"]))
#     unique_countries_int = [int(str(ord(c[0])) + str(ord(c[1]))) for c in unique_countries]
#     cnt_id_map = {float(v):k for k, v in enumerate(set(unique_countries_int))}

# if filename is None:
#     filename = f'./results/{args.model_name}/{datetime.now()}.txt'

write_file = None
# write_file = open(filename, "w")

# if write_file:
#     write_file.write(f'*** args: {args}\n\n')

# create model
collators, model = create_model_and_collator(
    args = args,
    model_name = args['model_name'],
    metadata=metadata, cnt_id_map = cnt_id_map

)
# Module.to returns the module itself; assigning it keeps the cell from echoing the
# full model repr as its output.
model = model.to(device)

In [None]:
# load data
data_loaders = create_dataset(
    args = args, collator_fns = collators, metadata = metadata, cnt_id_map = cnt_id_map
)

# train_label_stats = dataset_statistics(args, data_loaders[0])
# val_label_stats = dataset_statistics(args, data_loaders[1])
# print(f'*** Training set label statistics: {train_label_stats}')
# print(f'*** Validation set label statistics: {val_label_stats}')

# if write_file:
#     write_file.write(f'*** Training set label statistics: {train_label_stats}')
#     write_file.write(f'*** Validation set label statistics: {val_label_stats}')


# if args.model_name in ['logistic_regression', 'basic_cnn', 'dense_cnn']:
#     optim = torch.optim.Adam(params = model.parameters())
# else:
optim = torch.optim.AdamW(params=model.parameters(), lr=args['lr'], eps=args['eps'])

# The linear schedule decays the learning rate to zero over exactly this many
# optimizer steps (one per training batch).
total_steps = len(data_loaders[0]) * epoch_n
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=0, num_training_steps = total_steps)

# get class weights: 1 - relative frequency, so rarer classes weigh more.
# Reindexing on CLASS_NAMES guarantees one weight per class, in label order, even if a
# class never occurs in the CSV (its frequency is then 0 and its weight 1).
class_freq = df[args['outcome']].value_counts(normalize=True).reindex(CLASS_NAMES, fill_value=0.0)
class_weights = torch.tensor((1 - class_freq).to_numpy(), dtype=torch.float).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

if write_file:
    write_file.write(f'\nModel:\n {model}\nOptimizer:{optim}\n')

In [None]:
import pickle

In [None]:
# Where train() stores the best checkpoint. train() hands this to save_pretrained() for
# 'ViT', to torch.save() for the other models, and appends ".pkl" itself in mosaiks mode.
args.update(save_path='ViT_mosiaks_model.pkl')

In [None]:
args['lr'] = 1e-5

# The optimizer and scheduler were already built from the previous args['lr'], and the
# scheduler scales the learning rate the optimizer was created with. Rebuild both so the
# new value is actually used for training.
optim = torch.optim.AdamW(params=model.parameters(), lr=args['lr'], eps=args['eps'])
scheduler = get_linear_schedule_with_warmup(
    optim, num_warmup_steps=0, num_training_steps=len(data_loaders[0]) * epoch_n
)

In [None]:
# Close the log file even when training is interrupted or raises.
try:
    train(args, data_loaders, epoch_n, model, optim, scheduler, criterion, metadata, device, write_file)
finally:
    if write_file:
        write_file.close()


>>> Training starts...
*** Epoch: 0


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 1.119106650352478
*** Running accuracy on the train set: 0.078125



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [05:11<1:22:56, 311.00s/it][A
 12%|█▏        | 2/17 [05:12<32:15, 129.03s/it]  [A
 18%|█▊        | 3/17 [05:14<16:32, 70.89s/it] [A
 24%|██▎       | 4/17 [05:15<09:26, 43.54s/it][A
 29%|██▉       | 5/17 [05:17<05:41, 28.42s/it][A
 35%|███▌      | 6/17 [05:19<03:32, 19.29s/it][A
 41%|████      | 7/17 [05:21<02:16, 13.60s/it][A
 47%|████▋     | 8/17 [05:22<01:27,  9.77s/it][A
 53%|█████▎    | 9/17 [05:24<00:58,  7.33s/it][A
 59%|█████▉    | 10/17 [05:26<00:38,  5.56s/it][A
 65%|██████▍   | 11/17 [05:28<00:26,  4.47s/it][A
 71%|███████   | 12/17 [05:29<00:17,  3.59s/it][A
 76%|███████▋  | 13/17 [05:31<00:12,  3.00s/it][A
 82%|████████▏ | 14/17 [05:33<00:07,  2.60s/it][A
 88%|████████▊ | 15/17 [05:35<00:04,  2.45s/it][A
 94%|█████████▍| 16/17 [05:37<00:02,  2.43s/it][A
100%|██████████| 17/17 [05:39<00:00, 19.94s/it]


*** Accuracy on the Validation set: 0.17988929889298894
*** Weighted accuracy on the Validation set: 0.36829457364341084
*** Confusion matrix:
[[ 57.   0.  43.]
 [365.   0. 361.]
 [120.   0. 138.]]


100%|██████████| 97/97 [11:23<00:00,  7.04s/it]


*** Epoch: 1


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.8845539093017578
*** Running accuracy on the train set: 0.609375



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:27,  1.70s/it][A
 12%|█▏        | 2/17 [00:03<00:25,  1.72s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.76s/it][A
 24%|██▎       | 4/17 [00:06<00:22,  1.73s/it][A
 29%|██▉       | 5/17 [00:08<00:20,  1.71s/it][A
 35%|███▌      | 6/17 [00:10<00:18,  1.71s/it][A
 41%|████      | 7/17 [00:12<00:17,  1.71s/it][A
 47%|████▋     | 8/17 [00:13<00:15,  1.71s/it][A
 53%|█████▎    | 9/17 [00:15<00:13,  1.71s/it][A
 59%|█████▉    | 10/17 [00:17<00:11,  1.71s/it][A
 65%|██████▍   | 11/17 [00:18<00:10,  1.72s/it][A
 71%|███████   | 12/17 [00:20<00:08,  1.73s/it][A
 76%|███████▋  | 13/17 [00:22<00:06,  1.72s/it][A
 82%|████████▏ | 14/17 [00:24<00:05,  1.72s/it][A
 88%|████████▊ | 15/17 [00:25<00:03,  1.72s/it][A
 94%|█████████▍| 16/17 [00:27<00:01,  1.72s/it][A
100%|██████████| 17/17 [00:29<00:00,  1.71s/it]


*** Accuracy on the Validation set: 0.6190036900369004
*** Weighted accuracy on the Validation set: 0.5148749652977983
*** Confusion matrix:
[[ 45.  54.   1.]
 [106. 533.  87.]
 [ 16. 149.  93.]]


100%|██████████| 97/97 [05:29<00:00,  3.40s/it]


*** Epoch: 2


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.766907811164856
*** Running accuracy on the train set: 0.578125



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:27,  1.71s/it][A
 12%|█▏        | 2/17 [00:03<00:25,  1.72s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.74s/it][A
 24%|██▎       | 4/17 [00:06<00:22,  1.72s/it][A
 29%|██▉       | 5/17 [00:08<00:20,  1.71s/it][A
 35%|███▌      | 6/17 [00:10<00:18,  1.71s/it][A
 41%|████      | 7/17 [00:11<00:17,  1.71s/it][A
 47%|████▋     | 8/17 [00:13<00:15,  1.71s/it][A
 53%|█████▎    | 9/17 [00:15<00:13,  1.70s/it][A
 59%|█████▉    | 10/17 [00:17<00:11,  1.69s/it][A
 65%|██████▍   | 11/17 [00:18<00:10,  1.69s/it][A
 71%|███████   | 12/17 [00:20<00:08,  1.69s/it][A
 76%|███████▋  | 13/17 [00:22<00:06,  1.70s/it][A
 82%|████████▏ | 14/17 [00:23<00:05,  1.70s/it][A
 88%|████████▊ | 15/17 [00:25<00:03,  1.71s/it][A
 94%|█████████▍| 16/17 [00:27<00:01,  1.71s/it][A
100%|██████████| 17/17 [00:28<00:00,  1.70s/it]
  1%|          | 1/97 [00:31<51:05, 31.93s/it]

*** Accuracy on the Validation set: 0.6116236162361623
*** Weighted accuracy on the Validation set: 0.5888696692079356
*** Confusion matrix:
[[ 50.  46.   4.]
 [102. 444. 180.]
 [ 15.  74. 169.]]


100%|██████████| 97/97 [05:27<00:00,  3.37s/it]


*** Epoch: 3


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.6523650288581848
*** Running accuracy on the train set: 0.703125



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:27,  1.74s/it][A
 12%|█▏        | 2/17 [00:03<00:26,  1.74s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.74s/it][A
 24%|██▎       | 4/17 [00:06<00:22,  1.73s/it][A
 29%|██▉       | 5/17 [00:08<00:20,  1.71s/it][A
 35%|███▌      | 6/17 [00:10<00:18,  1.71s/it][A
 41%|████      | 7/17 [00:12<00:17,  1.71s/it][A
 47%|████▋     | 8/17 [00:13<00:15,  1.71s/it][A
 53%|█████▎    | 9/17 [00:15<00:13,  1.70s/it][A
 59%|█████▉    | 10/17 [00:17<00:11,  1.70s/it][A
 65%|██████▍   | 11/17 [00:18<00:10,  1.71s/it][A
 71%|███████   | 12/17 [00:20<00:08,  1.72s/it][A
 76%|███████▋  | 13/17 [00:22<00:06,  1.71s/it][A
 82%|████████▏ | 14/17 [00:23<00:05,  1.71s/it][A
 88%|████████▊ | 15/17 [00:25<00:03,  1.71s/it][A
 94%|█████████▍| 16/17 [00:27<00:01,  1.71s/it][A
100%|██████████| 17/17 [00:28<00:00,  1.71s/it]
  1%|          | 1/97 [00:32<51:13, 32.02s/it]

*** Accuracy on the Validation set: 0.6070110701107011
*** Weighted accuracy on the Validation set: 0.5707497811091892
*** Confusion matrix:
[[ 50.  47.   3.]
 [113. 458. 155.]
 [ 14.  94. 150.]]


100%|██████████| 97/97 [05:25<00:00,  3.35s/it]


*** Epoch: 4


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.6430875062942505
*** Running accuracy on the train set: 0.78125



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:27,  1.72s/it][A
 12%|█▏        | 2/17 [00:03<00:25,  1.73s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.74s/it][A
 24%|██▎       | 4/17 [00:06<00:22,  1.71s/it][A
 29%|██▉       | 5/17 [00:08<00:20,  1.70s/it][A
 35%|███▌      | 6/17 [00:10<00:18,  1.70s/it][A
 41%|████      | 7/17 [00:11<00:16,  1.70s/it][A
 47%|████▋     | 8/17 [00:13<00:15,  1.70s/it][A
 53%|█████▎    | 9/17 [00:15<00:13,  1.69s/it][A
 59%|█████▉    | 10/17 [00:17<00:11,  1.70s/it][A
 65%|██████▍   | 11/17 [00:18<00:10,  1.71s/it][A
 71%|███████   | 12/17 [00:20<00:08,  1.72s/it][A
 76%|███████▋  | 13/17 [00:22<00:06,  1.72s/it][A
 82%|████████▏ | 14/17 [00:23<00:05,  1.72s/it][A
 88%|████████▊ | 15/17 [00:25<00:03,  1.72s/it][A
 94%|█████████▍| 16/17 [00:27<00:01,  1.73s/it][A
100%|██████████| 17/17 [00:28<00:00,  1.70s/it]


*** Accuracy on the Validation set: 0.6328413284132841
*** Weighted accuracy on the Validation set: 0.5838674269118244
*** Confusion matrix:
[[ 64.  36.   0.]
 [149. 520.  57.]
 [ 18. 138. 102.]]


100%|██████████| 97/97 [05:28<00:00,  3.39s/it]


*** Epoch: 5


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.5946272015571594
*** Running accuracy on the train set: 0.734375



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:26,  1.69s/it][A
 12%|█▏        | 2/17 [00:03<00:25,  1.71s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.72s/it][A
 24%|██▎       | 4/17 [00:06<00:22,  1.71s/it][A
 29%|██▉       | 5/17 [00:08<00:20,  1.71s/it][A
 35%|███▌      | 6/17 [00:10<00:18,  1.71s/it][A
 41%|████      | 7/17 [00:11<00:17,  1.71s/it][A
 47%|████▋     | 8/17 [00:13<00:15,  1.71s/it][A
 53%|█████▎    | 9/17 [00:15<00:13,  1.70s/it][A
 59%|█████▉    | 10/17 [00:17<00:11,  1.70s/it][A
 65%|██████▍   | 11/17 [00:18<00:10,  1.71s/it][A
 71%|███████   | 12/17 [00:20<00:08,  1.71s/it][A
 76%|███████▋  | 13/17 [00:22<00:06,  1.71s/it][A
 82%|████████▏ | 14/17 [00:23<00:05,  1.71s/it][A
 88%|████████▊ | 15/17 [00:25<00:03,  1.71s/it][A
 94%|█████████▍| 16/17 [00:27<00:01,  1.71s/it][A
100%|██████████| 17/17 [00:28<00:00,  1.70s/it]
  1%|          | 1/97 [00:31<51:06, 31.94s/it]

*** Accuracy on the Validation set: 0.6217712177121771
*** Weighted accuracy on the Validation set: 0.5503177653917611
*** Confusion matrix:
[[ 47.  52.   1.]
 [117. 500. 109.]
 [ 12. 119. 127.]]


100%|██████████| 97/97 [05:26<00:00,  3.37s/it]


*** Epoch: 6


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.5705587863922119
*** Running accuracy on the train set: 0.78125



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:27,  1.70s/it][A
 12%|█▏        | 2/17 [00:03<00:25,  1.72s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.73s/it][A
 24%|██▎       | 4/17 [00:06<00:22,  1.72s/it][A
 29%|██▉       | 5/17 [00:08<00:20,  1.72s/it][A
 35%|███▌      | 6/17 [00:10<00:18,  1.71s/it][A
 41%|████      | 7/17 [00:11<00:17,  1.71s/it][A
 47%|████▋     | 8/17 [00:13<00:15,  1.71s/it][A
 53%|█████▎    | 9/17 [00:15<00:13,  1.71s/it][A
 59%|█████▉    | 10/17 [00:17<00:11,  1.70s/it][A
 65%|██████▍   | 11/17 [00:18<00:10,  1.70s/it][A
 71%|███████   | 12/17 [00:20<00:08,  1.71s/it][A
 76%|███████▋  | 13/17 [00:22<00:06,  1.71s/it][A
 82%|████████▏ | 14/17 [00:23<00:05,  1.71s/it][A
 88%|████████▊ | 15/17 [00:25<00:03,  1.71s/it][A
 94%|█████████▍| 16/17 [00:27<00:01,  1.72s/it][A
100%|██████████| 17/17 [00:28<00:00,  1.70s/it]
  1%|          | 1/97 [00:31<51:03, 31.91s/it]

*** Accuracy on the Validation set: 0.5821033210332104
*** Weighted accuracy on the Validation set: 0.520776048006492
*** Confusion matrix:
[[ 32.  59.   9.]
 [ 84. 432. 210.]
 [  6.  85. 167.]]


100%|██████████| 97/97 [05:25<00:00,  3.36s/it]


*** Epoch: 7


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.3262180685997009
*** Running accuracy on the train set: 0.921875



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:30,  1.93s/it][A
 12%|█▏        | 2/17 [00:03<00:26,  1.80s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.76s/it][A
 24%|██▎       | 4/17 [00:07<00:22,  1.73s/it][A
 29%|██▉       | 5/17 [00:08<00:20,  1.71s/it][A
 35%|███▌      | 6/17 [00:10<00:18,  1.71s/it][A
 41%|████      | 7/17 [00:12<00:17,  1.71s/it][A
 47%|████▋     | 8/17 [00:13<00:15,  1.71s/it][A
 53%|█████▎    | 9/17 [00:15<00:13,  1.71s/it][A
 59%|█████▉    | 10/17 [00:17<00:11,  1.71s/it][A
 65%|██████▍   | 11/17 [00:18<00:10,  1.72s/it][A
 71%|███████   | 12/17 [00:20<00:08,  1.72s/it][A
 76%|███████▋  | 13/17 [00:22<00:06,  1.72s/it][A
 82%|████████▏ | 14/17 [00:24<00:05,  1.72s/it][A
 88%|████████▊ | 15/17 [00:25<00:03,  1.74s/it][A
 94%|█████████▍| 16/17 [00:27<00:01,  1.73s/it][A
100%|██████████| 17/17 [00:29<00:00,  1.72s/it]


*** Accuracy on the Validation set: 0.6485239852398524
*** Weighted accuracy on the Validation set: 0.5194422021483333
*** Confusion matrix:
[[ 38.  60.   2.]
 [ 98. 560.  68.]
 [  9. 144. 105.]]


100%|██████████| 97/97 [05:28<00:00,  3.39s/it]


*** Epoch: 8


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.28685882687568665
*** Running accuracy on the train set: 0.921875



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:26,  1.69s/it][A
 12%|█▏        | 2/17 [00:03<00:25,  1.72s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.74s/it][A
 24%|██▎       | 4/17 [00:06<00:22,  1.73s/it][A
 29%|██▉       | 5/17 [00:08<00:20,  1.73s/it][A
 35%|███▌      | 6/17 [00:10<00:19,  1.73s/it][A
 41%|████      | 7/17 [00:12<00:17,  1.73s/it][A
 47%|████▋     | 8/17 [00:13<00:15,  1.73s/it][A
 53%|█████▎    | 9/17 [00:15<00:13,  1.73s/it][A
 59%|█████▉    | 10/17 [00:17<00:12,  1.74s/it][A
 65%|██████▍   | 11/17 [00:19<00:10,  1.74s/it][A
 71%|███████   | 12/17 [00:20<00:08,  1.73s/it][A
 76%|███████▋  | 13/17 [00:22<00:06,  1.74s/it][A
 82%|████████▏ | 14/17 [00:24<00:05,  1.73s/it][A
 88%|████████▊ | 15/17 [00:25<00:03,  1.73s/it][A
 94%|█████████▍| 16/17 [00:27<00:01,  1.72s/it][A
100%|██████████| 17/17 [00:29<00:00,  1.72s/it]
  1%|          | 1/97 [00:32<51:35, 32.25s/it]

*** Accuracy on the Validation set: 0.6254612546125461
*** Weighted accuracy on the Validation set: 0.5182843231469024
*** Confusion matrix:
[[ 30.  65.   5.]
 [ 77. 503. 146.]
 [  5. 108. 145.]]


100%|██████████| 97/97 [05:26<00:00,  3.36s/it]


*** Epoch: 9


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.171013742685318
*** Running accuracy on the train set: 1.0



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:27,  1.73s/it][A
 12%|█▏        | 2/17 [00:03<00:26,  1.76s/it][A
 18%|█▊        | 3/17 [00:05<00:24,  1.76s/it][A
 24%|██▎       | 4/17 [00:07<00:24,  1.90s/it][A
 29%|██▉       | 5/17 [00:09<00:23,  1.92s/it][A
 35%|███▌      | 6/17 [00:11<00:20,  1.91s/it][A
 41%|████      | 7/17 [00:12<00:18,  1.84s/it][A
 47%|████▋     | 8/17 [00:14<00:16,  1.81s/it][A
 53%|█████▎    | 9/17 [00:16<00:14,  1.77s/it][A
 59%|█████▉    | 10/17 [00:18<00:12,  1.75s/it][A
 65%|██████▍   | 11/17 [00:19<00:10,  1.74s/it][A
 71%|███████   | 12/17 [00:21<00:08,  1.74s/it][A
 76%|███████▋  | 13/17 [00:23<00:06,  1.72s/it][A
 82%|████████▏ | 14/17 [00:24<00:05,  1.72s/it][A
 88%|████████▊ | 15/17 [00:26<00:03,  1.71s/it][A
 94%|█████████▍| 16/17 [00:28<00:01,  1.70s/it][A
100%|██████████| 17/17 [00:29<00:00,  1.76s/it]


*** Accuracy on the Validation set: 0.6494464944649446
*** Weighted accuracy on the Validation set: 0.4814914472419758
*** Confusion matrix:
[[ 20.  76.   4.]
 [ 56. 563. 107.]
 [  3. 134. 121.]]


100%|██████████| 97/97 [05:28<00:00,  3.39s/it]


In [None]:
# Run the training loop. The close is placed in a `finally` so that the
# handle in `write_file` is released even when training is interrupted or
# raises (e.g. a manual KeyboardInterrupt), instead of being skipped.
try:
    train(args, data_loaders, epoch_n, model, optim, scheduler, criterion, metadata, device, write_file)
finally:
    # `write_file` may be falsy when no output file was opened.
    if write_file:
        write_file.close()


>>> Training starts...
*** Epoch: 0


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 1.1114825010299683
*** Running accuracy on the train set: 0.328125



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:27<07:23, 27.74s/it][A
 12%|█▏        | 2/17 [00:29<03:03, 12.24s/it][A
 18%|█▊        | 3/17 [00:30<01:42,  7.30s/it][A
 24%|██▎       | 4/17 [00:31<01:04,  4.98s/it][A
 29%|██▉       | 5/17 [00:33<00:44,  3.70s/it][A
 35%|███▌      | 6/17 [00:34<00:32,  2.93s/it][A
 41%|████      | 7/17 [00:36<00:24,  2.45s/it][A
 47%|████▋     | 8/17 [00:37<00:19,  2.15s/it][A
 53%|█████▎    | 9/17 [00:39<00:16,  2.07s/it][A
 59%|█████▉    | 10/17 [00:41<00:13,  1.89s/it][A
 65%|██████▍   | 11/17 [00:42<00:10,  1.76s/it][A
 71%|███████   | 12/17 [00:44<00:08,  1.67s/it][A
 76%|███████▋  | 13/17 [00:45<00:06,  1.61s/it][A
 82%|████████▏ | 14/17 [00:46<00:04,  1.54s/it][A
 88%|████████▊ | 15/17 [00:48<00:02,  1.48s/it][A
 94%|█████████▍| 16/17 [00:49<00:01,  1.47s/it][A
100%|██████████| 17/17 [00:51<00:00,  3.01s/it]


*** Accuracy on the Validation set: 0.3284132841328413
*** Weighted accuracy on the Validation set: 0.4130311455612661
*** Confusion matrix:
[[ 48.  34.  38.]
 [165. 152. 398.]
 [ 43.  50. 156.]]


100%|██████████| 97/97 [05:34<00:00,  3.44s/it]


*** Epoch: 1


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.6878499388694763
*** Running accuracy on the train set: 0.734375



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.55s/it][A
 12%|█▏        | 2/17 [00:03<00:23,  1.57s/it][A
 18%|█▊        | 3/17 [00:04<00:22,  1.58s/it][A
 24%|██▎       | 4/17 [00:06<00:20,  1.59s/it][A
 29%|██▉       | 5/17 [00:07<00:19,  1.59s/it][A
 35%|███▌      | 6/17 [00:09<00:17,  1.58s/it][A
 41%|████      | 7/17 [00:11<00:15,  1.58s/it][A
 47%|████▋     | 8/17 [00:12<00:14,  1.58s/it][A
 53%|█████▎    | 9/17 [00:14<00:12,  1.59s/it][A
 59%|█████▉    | 10/17 [00:15<00:11,  1.60s/it][A
 65%|██████▍   | 11/17 [00:17<00:09,  1.61s/it][A
 71%|███████   | 12/17 [00:19<00:08,  1.67s/it][A
 76%|███████▋  | 13/17 [00:21<00:07,  1.77s/it][A
 82%|████████▏ | 14/17 [00:23<00:05,  1.75s/it][A
 88%|████████▊ | 15/17 [00:24<00:03,  1.76s/it][A
 94%|█████████▍| 16/17 [00:26<00:01,  1.71s/it][A
100%|██████████| 17/17 [00:27<00:00,  1.63s/it]


*** Accuracy on the Validation set: 0.6494464944649446
*** Weighted accuracy on the Validation set: 0.5715982812368355
*** Confusion matrix:
[[ 60.  59.   1.]
 [ 83. 524. 108.]
 [ 13. 116. 120.]]


100%|██████████| 97/97 [05:10<00:00,  3.20s/it]


*** Epoch: 2


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.74836665391922
*** Running accuracy on the train set: 0.65625



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.53s/it][A
 12%|█▏        | 2/17 [00:03<00:23,  1.54s/it][A
 18%|█▊        | 3/17 [00:04<00:21,  1.56s/it][A
 24%|██▎       | 4/17 [00:06<00:20,  1.58s/it][A
 29%|██▉       | 5/17 [00:07<00:18,  1.58s/it][A
 35%|███▌      | 6/17 [00:09<00:17,  1.58s/it][A
 41%|████      | 7/17 [00:11<00:15,  1.59s/it][A
 47%|████▋     | 8/17 [00:12<00:14,  1.57s/it][A
 53%|█████▎    | 9/17 [00:14<00:12,  1.57s/it][A
 59%|█████▉    | 10/17 [00:15<00:11,  1.58s/it][A
 65%|██████▍   | 11/17 [00:17<00:09,  1.58s/it][A
 71%|███████   | 12/17 [00:18<00:07,  1.58s/it][A
 76%|███████▋  | 13/17 [00:20<00:06,  1.58s/it][A
 82%|████████▏ | 14/17 [00:22<00:04,  1.57s/it][A
 88%|████████▊ | 15/17 [00:23<00:03,  1.56s/it][A
 94%|█████████▍| 16/17 [00:25<00:01,  1.56s/it][A
100%|██████████| 17/17 [00:26<00:00,  1.56s/it]


*** Accuracy on the Validation set: 0.6614391143911439
*** Weighted accuracy on the Validation set: 0.5689791801237584
*** Confusion matrix:
[[ 57.  61.   2.]
 [ 80. 542.  93.]
 [ 17. 114. 118.]]


100%|██████████| 97/97 [05:08<00:00,  3.18s/it]


*** Epoch: 3


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.644568145275116
*** Running accuracy on the train set: 0.75



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.52s/it][A
 12%|█▏        | 2/17 [00:03<00:22,  1.52s/it][A
 18%|█▊        | 3/17 [00:04<00:21,  1.55s/it][A
 24%|██▎       | 4/17 [00:06<00:20,  1.56s/it][A
 29%|██▉       | 5/17 [00:07<00:18,  1.56s/it][A
 35%|███▌      | 6/17 [00:09<00:17,  1.60s/it][A
 41%|████      | 7/17 [00:11<00:15,  1.59s/it][A
 47%|████▋     | 8/17 [00:12<00:14,  1.58s/it][A
 53%|█████▎    | 9/17 [00:14<00:12,  1.57s/it][A
 59%|█████▉    | 10/17 [00:15<00:11,  1.58s/it][A
 65%|██████▍   | 11/17 [00:17<00:09,  1.57s/it][A
 71%|███████   | 12/17 [00:18<00:07,  1.56s/it][A
 76%|███████▋  | 13/17 [00:20<00:06,  1.56s/it][A
 82%|████████▏ | 14/17 [00:21<00:04,  1.56s/it][A
 88%|████████▊ | 15/17 [00:23<00:03,  1.55s/it][A
 94%|█████████▍| 16/17 [00:24<00:01,  1.54s/it][A
100%|██████████| 17/17 [00:26<00:00,  1.55s/it]
  1%|          | 1/97 [00:29<46:43, 29.21s/it]

*** Accuracy on the Validation set: 0.6383763837638377
*** Weighted accuracy on the Validation set: 0.5990228981192837
*** Confusion matrix:
[[ 69.  47.   4.]
 [ 97. 489. 129.]
 [ 18.  97. 134.]]


100%|██████████| 97/97 [05:05<00:00,  3.15s/it]


*** Epoch: 4


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.7508125901222229
*** Running accuracy on the train set: 0.671875



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.52s/it][A
 12%|█▏        | 2/17 [00:03<00:23,  1.55s/it][A
 18%|█▊        | 3/17 [00:04<00:21,  1.55s/it][A
 24%|██▎       | 4/17 [00:06<00:20,  1.56s/it][A
 29%|██▉       | 5/17 [00:07<00:18,  1.55s/it][A
 35%|███▌      | 6/17 [00:09<00:16,  1.54s/it][A
 41%|████      | 7/17 [00:10<00:15,  1.55s/it][A
 47%|████▋     | 8/17 [00:12<00:14,  1.56s/it][A
 53%|█████▎    | 9/17 [00:13<00:12,  1.56s/it][A
 59%|█████▉    | 10/17 [00:15<00:10,  1.57s/it][A
 65%|██████▍   | 11/17 [00:17<00:09,  1.57s/it][A
 71%|███████   | 12/17 [00:18<00:07,  1.56s/it][A
 76%|███████▋  | 13/17 [00:20<00:06,  1.56s/it][A
 82%|████████▏ | 14/17 [00:21<00:04,  1.55s/it][A
 88%|████████▊ | 15/17 [00:23<00:03,  1.54s/it][A
 94%|█████████▍| 16/17 [00:24<00:01,  1.54s/it][A
100%|██████████| 17/17 [00:26<00:00,  1.54s/it]
  1%|          | 1/97 [00:29<46:25, 29.02s/it]

*** Accuracy on the Validation set: 0.6134686346863468
*** Weighted accuracy on the Validation set: 0.5654609112440437
*** Confusion matrix:
[[ 52.  64.   4.]
 [ 79. 458. 178.]
 [ 13.  81. 155.]]


100%|██████████| 97/97 [05:06<00:00,  3.16s/it]


*** Epoch: 5


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.39360618591308594
*** Running accuracy on the train set: 0.953125



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.52s/it][A
 12%|█▏        | 2/17 [00:03<00:23,  1.55s/it][A
 18%|█▊        | 3/17 [00:04<00:21,  1.56s/it][A
 24%|██▎       | 4/17 [00:06<00:20,  1.56s/it][A
 29%|██▉       | 5/17 [00:07<00:18,  1.56s/it][A
 35%|███▌      | 6/17 [00:09<00:17,  1.56s/it][A
 41%|████      | 7/17 [00:10<00:15,  1.56s/it][A
 47%|████▋     | 8/17 [00:12<00:13,  1.55s/it][A
 53%|█████▎    | 9/17 [00:13<00:12,  1.55s/it][A
 59%|█████▉    | 10/17 [00:15<00:10,  1.57s/it][A
 65%|██████▍   | 11/17 [00:17<00:09,  1.55s/it][A
 71%|███████   | 12/17 [00:18<00:07,  1.55s/it][A
 76%|███████▋  | 13/17 [00:20<00:06,  1.55s/it][A
 82%|████████▏ | 14/17 [00:21<00:04,  1.54s/it][A
 88%|████████▊ | 15/17 [00:23<00:03,  1.53s/it][A
 94%|█████████▍| 16/17 [00:24<00:01,  1.54s/it][A
100%|██████████| 17/17 [00:26<00:00,  1.54s/it]


*** Accuracy on the Validation set: 0.6715867158671587
*** Weighted accuracy on the Validation set: 0.49093577105625297
*** Confusion matrix:
[[ 18. 100.   2.]
 [ 18. 584. 113.]
 [  2. 121. 126.]]


100%|██████████| 97/97 [05:08<00:00,  3.19s/it]


*** Epoch: 6


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.47175732254981995
*** Running accuracy on the train set: 0.8125



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.52s/it][A
 12%|█▏        | 2/17 [00:03<00:23,  1.55s/it][A
 18%|█▊        | 3/17 [00:04<00:21,  1.56s/it][A
 24%|██▎       | 4/17 [00:06<00:20,  1.57s/it][A
 29%|██▉       | 5/17 [00:07<00:18,  1.57s/it][A
 35%|███▌      | 6/17 [00:09<00:17,  1.57s/it][A
 41%|████      | 7/17 [00:10<00:15,  1.56s/it][A
 47%|████▋     | 8/17 [00:12<00:13,  1.55s/it][A
 53%|█████▎    | 9/17 [00:14<00:12,  1.56s/it][A
 59%|█████▉    | 10/17 [00:15<00:10,  1.56s/it][A
 65%|██████▍   | 11/17 [00:17<00:09,  1.56s/it][A
 71%|███████   | 12/17 [00:18<00:07,  1.55s/it][A
 76%|███████▋  | 13/17 [00:20<00:06,  1.54s/it][A
 82%|████████▏ | 14/17 [00:21<00:04,  1.53s/it][A
 88%|████████▊ | 15/17 [00:23<00:03,  1.53s/it][A
 94%|█████████▍| 16/17 [00:24<00:01,  1.53s/it][A
100%|██████████| 17/17 [00:26<00:00,  1.54s/it]
  1%|          | 1/97 [00:28<46:18, 28.94s/it]

*** Accuracy on the Validation set: 0.6023985239852399
*** Weighted accuracy on the Validation set: 0.5319120304060063
*** Confusion matrix:
[[ 35.  75.  10.]
 [ 48. 450. 217.]
 [  6.  75. 168.]]


100%|██████████| 97/97 [05:04<00:00,  3.14s/it]


*** Epoch: 7


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.3748129904270172
*** Running accuracy on the train set: 0.90625



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.51s/it][A
 12%|█▏        | 2/17 [00:03<00:23,  1.53s/it][A
 18%|█▊        | 3/17 [00:04<00:21,  1.54s/it][A
 24%|██▎       | 4/17 [00:06<00:19,  1.53s/it][A
 29%|██▉       | 5/17 [00:07<00:18,  1.53s/it][A
 35%|███▌      | 6/17 [00:09<00:16,  1.53s/it][A
 41%|████      | 7/17 [00:10<00:15,  1.54s/it][A
 47%|████▋     | 8/17 [00:12<00:13,  1.54s/it][A
 53%|█████▎    | 9/17 [00:13<00:12,  1.54s/it][A
 59%|█████▉    | 10/17 [00:15<00:10,  1.55s/it][A
 65%|██████▍   | 11/17 [00:16<00:09,  1.55s/it][A
 71%|███████   | 12/17 [00:18<00:07,  1.54s/it][A
 76%|███████▋  | 13/17 [00:19<00:06,  1.54s/it][A
 82%|████████▏ | 14/17 [00:21<00:04,  1.53s/it][A
 88%|████████▊ | 15/17 [00:23<00:03,  1.53s/it][A
 94%|█████████▍| 16/17 [00:24<00:01,  1.53s/it][A
100%|██████████| 17/17 [00:25<00:00,  1.53s/it]
  1%|          | 1/97 [00:28<46:02, 28.77s/it]

*** Accuracy on the Validation set: 0.6605166051660517
*** Weighted accuracy on the Validation set: 0.5155622489959839
*** Confusion matrix:
[[ 39.  79.   2.]
 [ 54. 572.  89.]
 [  6. 138. 105.]]


100%|██████████| 97/97 [05:04<00:00,  3.13s/it]


*** Epoch: 8


  0%|          | 0/97 [00:00<?, ?it/s]

*** Average Loss: 0.24936728179454803
*** Running accuracy on the train set: 0.9375



  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.51s/it][A
 12%|█▏        | 2/17 [00:03<00:22,  1.52s/it][A
 18%|█▊        | 3/17 [00:04<00:21,  1.53s/it][A
 24%|██▎       | 4/17 [00:06<00:20,  1.55s/it][A
 29%|██▉       | 5/17 [00:07<00:18,  1.54s/it][A
 35%|███▌      | 6/17 [00:09<00:16,  1.53s/it][A
 41%|████      | 7/17 [00:10<00:15,  1.54s/it][A
 47%|████▋     | 8/17 [00:12<00:13,  1.53s/it][A
 53%|█████▎    | 9/17 [00:13<00:12,  1.53s/it][A
 59%|█████▉    | 10/17 [00:15<00:10,  1.54s/it][A
 65%|██████▍   | 11/17 [00:16<00:09,  1.54s/it][A
 71%|███████   | 12/17 [00:18<00:07,  1.53s/it][A
 76%|███████▋  | 13/17 [00:19<00:06,  1.53s/it][A
 82%|████████▏ | 14/17 [00:21<00:04,  1.52s/it][A
 88%|████████▊ | 15/17 [00:22<00:03,  1.53s/it][A
 94%|█████████▍| 16/17 [00:24<00:01,  1.53s/it][A
100%|██████████| 17/17 [00:25<00:00,  1.53s/it]
  1%|          | 1/97 [00:28<45:55, 28.70s/it]

*** Accuracy on the Validation set: 0.6605166051660517
*** Weighted accuracy on the Validation set: 0.49375005850909465
*** Confusion matrix:
[[ 39.  79.   2.]
 [ 52. 597.  66.]
 [  5. 164.  80.]]


  3%|▎         | 3/97 [00:35<18:39, 11.91s/it]


KeyboardInterrupt: ignored

In [None]:
## MOSAIKS RESULTS

In [None]:
# Table that later cells split by its "DHSID" column and read the label
# columns from (e.g. "Mean_BMI_bin").
df = pd.read_csv(filepath_or_buffer="west_africa_df.csv")

In [None]:
# Saved index array (presumably the shuffled split order, going by the file
# name -- TODO confirm it is the same one used to build `data_loaders`).
indices = np.load(file="indices_perm2.npy")
# Validation size: 15% of the indices, rounded down.
n_val = int(np.floor(0.15 * len(indices)))

In [None]:
# Recover the DHSIDs of the training images: take the image file names held by
# the dataset behind the first data loader, keep the ones picked by the
# training part of `indices` (everything except the last `n_val` entries), and
# drop everything from the first "." onwards (the file extension).
#
# The slice end is written as an explicit length rather than `-n_val`:
# `indices[:-0]` is an empty slice, so a tiny dataset with n_val == 0 would
# otherwise silently produce an empty training set.
train_dhsids = pd.Series(np.array(
    data_loaders[0].dataset.dataset.image_paths
)[indices[:len(indices) - n_val]]).apply(lambda x: x.split('.')[0])

# Rows of `df` whose DHSID is a training image form the training frame; every
# other row of `df` is treated as validation.
# NOTE(review): this also puts any `df` row without a matching image into the
# validation frame -- confirm that is intended.
train_df = df.loc[df["DHSID"].isin(train_dhsids)]
val_df = df.loc[~df["DHSID"].isin(train_dhsids)]

In [None]:
# NOTE(review): this rebinds `LogisticRegression`, which the import cell at the
# top of the notebook already imported from the local `models` module. From
# this cell on, the name refers to scikit-learn's estimator.
from sklearn.linear_model import LogisticRegression

In [None]:
# Feature table; later cells filter it by "DHSID" and select the feature
# columns listed in `mosaiks_feats`.
mosaiks_df = pd.read_csv(filepath_or_buffer="west_africa_mosaiks_feats.csv")

In [None]:
# Split the feature table the same way `df` was split: rows whose DHSID is a
# training image go to train, all remaining rows go to validation.
_mosaiks_in_train = mosaiks_df['DHSID'].isin(train_dhsids)
train_mosaiks = mosaiks_df.loc[_mosaiks_in_train]
val_mosaiks = mosaiks_df.loc[~_mosaiks_in_train]

# Later cells fit on features taken from `mosaiks_df` and labels taken from
# `df`, pairing them purely by row position. Fail loudly here if the two
# tables do not list the same DHSIDs in the same order, because a mismatch
# would otherwise train and score against the wrong labels without any error.
for _split, _feat_ids, _label_ids in (
    ("train", train_mosaiks['DHSID'], train_df['DHSID']),
    ("validation", val_mosaiks['DHSID'], val_df['DHSID']),
):
    if not np.array_equal(_feat_ids.to_numpy(), _label_ids.to_numpy()):
        raise ValueError(
            "DHSID rows of the feature table and the label table are not "
            "aligned for the " + _split + " split"
        )

In [None]:
# Names of the feature columns to select: " .1", " .2", ..., " .3999".
# NOTE(review): these look like pandas' de-duplicated names for repeated blank
# headers; if so, the un-suffixed first column (" ") is not included here --
# confirm whether that column is a feature.
mosaiks_feats = [f" .{col_no}" for col_no in range(1, 4000)]

In [None]:
# Keep only the feature columns named in `mosaiks_feats` for each split.
train_mosaiks_fts = train_mosaiks.loc[:, mosaiks_feats]
val_mosaiks_fts = val_mosaiks.loc[:, mosaiks_feats]

In [None]:
# MEAN BMI BIN
# Class-weighted logistic regression on the training features, scored on the
# validation split.
lr = LogisticRegression(
    class_weight='balanced',
    max_iter=1000,
    tol=0.001,
    random_state=231,
)
lr.fit(train_mosaiks_fts, train_df["Mean_BMI_bin"])
preds = lr.predict(val_mosaiks_fts)

# Rows of the confusion matrix are the true classes, so diagonal / row sum is
# the per-class recall; the displayed value is its unweighted mean.
cm = confusion_matrix(val_df["Mean_BMI_bin"], preds, labels=CLASS_NAMES)
(cm.diagonal() / cm.sum(axis=1)).mean()

In [None]:
# Recompute the mean per-class recall for "Mean_BMI_bin" from whatever `preds`
# currently holds. This repeats the metric from the cell above, so the value is
# only meaningful while `preds` still comes from the Mean BMI model.
cm = confusion_matrix(val_df["Mean_BMI_bin"], preds, labels=CLASS_NAMES)
(cm.diagonal() / cm.sum(axis=1)).mean()

0.46771063281824876

In [None]:
# UNDER 5 BIN
# Class-weighted logistic regression for the under-5 mortality bins.
# The iteration cap is raised from 2000 to 4000, matching the quintile models
# in the cells below: with 2000 the solver stopped at the iteration limit
# (convergence warning), so the score came from a model that had not converged.
# NOTE(review): if the warning persists at 4000, scale the features as the
# scikit-learn warning suggests rather than raising the cap further.
u5_lr = LogisticRegression(random_state = 231, class_weight = 'balanced', max_iter = 4000)
u5_lr.fit(train_mosaiks_fts, train_df["Under5_Mortality_Rate_bin"])
preds = u5_lr.predict(val_mosaiks_fts)

# Rows of the confusion matrix are the true classes, so diagonal / row sum is
# the per-class recall; the displayed value is its unweighted mean.
cm = confusion_matrix(val_df["Under5_Mortality_Rate_bin"], preds, labels=CLASS_NAMES)
np.mean(
    np.diag(cm) / np.sum(cm, 1)
)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


0.3884044720511475

In [None]:
# UNDER 5 BIN QUINT
# Class-weighted logistic regression for the under-5 mortality quintile label.
u5q_lr = LogisticRegression(
    class_weight='balanced',
    max_iter=4000,
    random_state=231,
)
u5q_lr.fit(train_mosaiks_fts, train_df["Under5_Mortality_Rate_bin_quint"])
preds = u5q_lr.predict(val_mosaiks_fts)

# No explicit `labels=` here: the matrix covers the classes that occur in the
# true or predicted values. The displayed value is the mean per-class recall.
cm = confusion_matrix(val_df["Under5_Mortality_Rate_bin_quint"], preds)
(cm.diagonal() / cm.sum(axis=1)).mean()

0.24587755444298157

In [None]:
# MEAN BMI QUINT
# Class-weighted logistic regression for the mean-BMI quintile label.
bmiq_lr = LogisticRegression(
    class_weight='balanced',
    max_iter=4000,
    random_state=231,
)
bmiq_lr.fit(train_mosaiks_fts, train_df["Mean_BMI_bin_quint"])
preds = bmiq_lr.predict(val_mosaiks_fts)

# No explicit `labels=` here: the matrix covers the classes that occur in the
# true or predicted values. The displayed value is the mean per-class recall.
cm = confusion_matrix(val_df["Mean_BMI_bin_quint"], preds)
(cm.diagonal() / cm.sum(axis=1)).mean()

0.25441700648433796

In [None]:
# UNDER 5 BIN QUINT
# Same under-5 quintile model as the earlier cell, refit with stronger
# regularisation (C = 0.1). This rebinds `u5q_lr`, replacing the earlier fit.
u5q_lr = LogisticRegression(
    C=0.1,
    class_weight='balanced',
    max_iter=4000,
    random_state=231,
)
u5q_lr.fit(train_mosaiks_fts, train_df["Under5_Mortality_Rate_bin_quint"])
preds = u5q_lr.predict(val_mosaiks_fts)

# No explicit `labels=` here: the matrix covers the classes that occur in the
# true or predicted values. The displayed value is the mean per-class recall.
cm = confusion_matrix(val_df["Under5_Mortality_Rate_bin_quint"], preds)
(cm.diagonal() / cm.sum(axis=1)).mean()

0.24932198275131584