In [1]:
import pandas as pd
import numpy as np

import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from tqdm import tqdm

import torch
from PIL import Image
from torchvision.transforms import v2
from torchvision import transforms

import os

from sklearn import preprocessing




In [2]:
backends = [
  'opencv', 
  'ssd', 
  'dlib', 
  'mtcnn', 
  'retinaface', 
  'mediapipe',
  'yolov8',
  'yunet',
  'fastmtcnn',
]

if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [3]:
data = pd.read_csv("/kaggle/input/fac-data-p1/labels.csv")
path = "/kaggle/input/fac-data-p1/data/mnt/md0/projects/sami-hackathon/private/data/"
cols = data.columns
y_raw = data[cols[4:]].values
files_list = data['file_name'].values

In [None]:
data.head()

# label preprocessing

In [4]:
labels_set = {}
for col in cols[4:]:
    temp = data[col].unique()
    labels_set[col] = temp
    
labels_set['age'] = ['Baby', 'Kid', 'Teenager', '20-30s', '40-50s', 'Senior']
print(labels_set)

labels = dict()
for col in cols[4:]:
    enc = preprocessing.OneHotEncoder(categories=[labels_set[col]]).fit(data[col].values.reshape(-1 ,1))
    labels[col] = enc.transform(data[col].values.reshape(-1 ,1)).toarray()
    print(enc.categories_)

temp = data[['file_name', 'bbox']].values
file_names = temp[:, 0]
bboxs = temp[:, 1]
bbxs = []
for i in range(len(bboxs)):
    bbxs.append(eval(bboxs[i]))
bboxes = np.array(bbxs, int)
    

{'age': ['Baby', 'Kid', 'Teenager', '20-30s', '40-50s', 'Senior'], 'race': array(['Caucasian', 'Mongoloid', 'Negroid'], dtype=object), 'masked': array(['unmasked', 'masked'], dtype=object), 'skintone': array(['mid-light', 'light', 'mid-dark', 'dark'], dtype=object), 'emotion': array(['Neutral', 'Happiness', 'Anger', 'Surprise', 'Fear', 'Sadness',
       'Disgust'], dtype=object), 'gender': array(['Male', 'Female'], dtype=object)}
[array(['Baby', 'Kid', 'Teenager', '20-30s', '40-50s', 'Senior'],
      dtype=object)]
[array(['Caucasian', 'Mongoloid', 'Negroid'], dtype=object)]
[array(['unmasked', 'masked'], dtype=object)]
[array(['mid-light', 'light', 'mid-dark', 'dark'], dtype=object)]
[array(['Neutral', 'Happiness', 'Anger', 'Surprise', 'Fear', 'Sadness',
       'Disgust'], dtype=object)]
[array(['Male', 'Female'], dtype=object)]


In [5]:
weights = {}
 
for col in cols[4:]:
    #print(col)
    temp = data[col].value_counts()
    #print(temp[labels_set[col]])
    temp = temp[labels_set[col]].values
    classes = len(temp)
    n = np.sum(temp)
    
    temp = n / (classes * temp)
    
    weights[col] = temp

weights


{'age': array([7.39613527, 2.674703  , 4.76057214, 0.22709742, 1.59280067,
        4.00575615]),
 'race': array([0.71817244, 0.68162593, 7.11761971]),
 'masked': array([ 0.51702013, 15.18849206]),
 'skintone': array([ 1.03782538,  0.3650453 ,  4.79636591, 11.29056047]),
 'emotion': array([ 0.45151587,  0.2372687 ,  6.8562472 ,  7.21829326, 19.18546366,
         5.7556391 , 16.56926407]),
 'gender': array([1.59878864, 0.72752328])}

# Data loader

In [7]:

mytransform2 = transforms.Compose([
            #transforms.RandomHorizontalFlip(),
            #transforms.Resize((224, 224)),
            transforms.ToTensor(),  # mmb
        ]
        )
class Dataset2(torch.utils.data.Dataset):

    def __init__(self, data, labels, root):
        self.labels = labels
        self.data = data
        self.root = root


    def __len__(self):
        return len(self.data)


    def __getitem__(self, index):
        X = self.data[index]
        y = dict()
        y["age"] = labels['age'][index]
        y["race"] = labels['race'][index]
        y["masked"] = labels['masked'][index]
        y["skintone"] = labels['skintone'][index]
        y["emotion"] = labels['emotion'][index]
        y["gender"] = labels['gender'][index]

        X = mytransform2(X)

        return X, y

    
mytransform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),  # mmb
        ]
        )


In [8]:
from torch.utils.data import random_split

data_preprocessed = np.load('/kaggle/input/face-aligned/np_images.npy')

root = "/kaggle/input/pixta-train-face/images_face"
data_set = Dataset2(data_preprocessed, labels, root)

generator1 = torch.Generator().manual_seed(42)
train_set, test_set = random_split(data_set, [0.8, 0.2], generator = generator1)

# Model

In [9]:
import torch.nn.functional as F

class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

class InceptionBlock(nn.Module):
    def __init__(
        self, 
        in_channels, 
        out_1x1,
        red_3x3,
        out_3x3,
        red_5x5,
        out_5x5,
        out_pool,
    ):
        super(InceptionBlock, self).__init__()
        self.branch1 = BasicConv2d(in_channels, out_1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, red_3x3, kernel_size=1, padding=0),
            BasicConv2d(red_3x3, out_3x3, kernel_size=3, padding=1),
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, red_5x5, kernel_size=1),
            BasicConv2d(red_5x5, out_5x5, kernel_size=5, padding=2),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            BasicConv2d(in_channels, out_pool, kernel_size=1),
        )
    
    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)

    
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

    
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
    



In [10]:
class SkintoneRaceModel(nn.Module):
    def __init__(self):
        super(SkintoneRaceModel, self).__init__()
        self.conv5x5x48_1 = BasicConv2d(3, 48, kernel_size = (5, 5), stride = 2, padding = 2)
        self.conv3x3x96_1 = BasicConv2d(48, 96, kernel_size = (3, 3), stride = 1, padding = 1)
        self.conv3x3x192_1 = BasicConv2d(96, 192, kernel_size = (3, 3), stride = 1, padding = 1)
        self.conv3x3x384 = BasicConv2d(192, 384, kernel_size = (3, 3), stride = 1, padding = 1)
        
        self.conv5x5x48 = BasicConv2d(3, 48, kernel_size = (5, 5), stride = 2, padding = 1)
        self.conv3x3x96 = BasicConv2d(48, 96, kernel_size = (3, 3), stride = 1, padding = 1)
        self.conv3x3x192 = BasicConv2d(96, 192, kernel_size = (3, 3), stride = 1, padding = 1)
        
        self.inception1 = InceptionBlock(192, 128, 128, 256, 128, 256, 128)
        self.conv3x3x512 = BasicConv2d(128 + 256 + 256 + 128, 512, kernel_size = (3, 3), stride = 1, padding = 1)
        
        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        
        self.fc_race = nn.Linear(512, 3)
        self.fc_skintone = nn.Linear(384, 4)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        x_skin = self.conv5x5x48_1(x)
        x_skin = nn.functional.max_pool2d(x_skin, (2, 2), 2)
        
        x_skin = self.conv3x3x96_1(x_skin)
        x_skin = nn.functional.max_pool2d(x_skin, (2, 2), 2)
        
        x_skin = self.conv3x3x192_1(x_skin)
        x_skin = nn.functional.max_pool2d(x_skin, (2, 2), 2)
        
        x_skin = self.conv3x3x384(x_skin)
        x_skin = nn.functional.max_pool2d(x_skin, (2, 2), 2)
        
        N, C, W, H = x_skin.shape
        x_skin = self.gap(x_skin).view(N, -1)
        x_skin = self.dropout(x_skin)
        x_skin = self.fc_skintone(x_skin)
        x_skin = nn.functional.softmax(x_skin, dim = 1)
            
        #  RACE
        x_race = self.conv5x5x48(x)
        x_race = nn.functional.max_pool2d(x_race, (2, 2), 2)
        
        x_race = self.conv3x3x96(x_race)
        x_race = nn.functional.max_pool2d(x_race, (2, 2), 2)
        
        x_race = self.conv3x3x192(x_race)
        x_race = nn.functional.max_pool2d(x_race, (2, 2), 2)
        
        x_race = self.inception1(x_race)
        x_race = nn.functional.max_pool2d(x_race, (2, 2), 2)
        # 1152 x 14 x 14
        
        x_race = self.conv3x3x512(x_race)
        x_race = nn.functional.max_pool2d(x_race, (2, 2), 2)
        
        N, C, W, H = x_race.shape
        #print(x.shape)
        #x = x.view(N, -1)
        x_race = self.gap(x_race).view(N, -1)
        x_race = self.dropout(x_race)
        x_race = self.fc_race(x_race)
        x_race = nn.functional.softmax(x_race, dim = 1)
        
        return {"race": x_race, "skintone": x_skin}

In [11]:
class MaskedModel(nn.Module):
    def __init__(self,  att_out = False):
        super(MaskedModel, self).__init__()
        self.att_out = att_out
        self.conv5x5x3 = BasicConv2d(3, 3, kernel_size = (5, 5), stride = 1, padding = 2)
        self.conv3x3x96 = BasicConv2d(3, 96, kernel_size = (3, 3), stride = 1, padding = 1)
        self.conv3x3x192 = BasicConv2d(96, 192, kernel_size = (3, 3), stride = 1, padding = 1)
        self.conv3x3x256 = BasicConv2d(192, 256, kernel_size = (3, 3), stride = 1, padding = 1)
        self.inception1 = InceptionBlock(256, 192, 128, 384, 96, 384, 192)
        
        self.spatial_module = SpatialAttention()
        
        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(192 + 384 + 384 + 192, 2)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        #x: 3 x 224 x 224
        x = self.conv5x5x3(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        x = self.conv3x3x96(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        x = self.conv3x3x192(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        x = self.conv3x3x256(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        x = self.inception1(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        N, C, W, H = x.shape
        sp_att = self.spatial_module(x).view(-1, 1, W, H)

        sp_att2 = sp_att.expand(-1, C, W, H)
        x = x * sp_att2
        
        # 896 x 7 x 7
        x = self.gap(x).view(-1, C)
        x = self.dropout(x)
        
        x = self.fc(x)
        
        x = nn.functional.softmax(x, dim = 1)
        
        if self.att_out:
            return {"masked": x}, sp_att
        
        return {"masked": x}

In [12]:
class GenderModel(nn.Module):
    def __init__(self):
        super(GenderModel, self).__init__()
        self.conv5x5x3 = BasicConv2d(3, 3, kernel_size = (5, 5), stride = 1, padding = 2)
        self.conv3x3x96 = BasicConv2d(3, 96, kernel_size = (3, 3), stride = 1, padding = 1)
        self.conv3x3x192 = BasicConv2d(96, 192, kernel_size = (3, 3), stride = 1, padding = 1)
        self.conv3x3x384 = BasicConv2d(192, 384, kernel_size = (3, 3), stride = 1, padding = 1)
        self.inception1 = InceptionBlock(384, 192, 128, 384, 128, 384, 192)
        
        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        
        self.fc = nn.Linear(192 + 384 + 384 + 192, 2)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        #x: 3 x 224 x 224
        x = self.conv5x5x3(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        x = self.conv3x3x96(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        x = self.conv3x3x192(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        x = self.conv3x3x384(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        x = self.inception1(x)
        x = nn.functional.max_pool2d(x, (2, 2), 2)
        
        N, C, W, H = x.shape
        # 896 x 7 x 7
        x = self.gap(x).view(N, C)
        x = self.dropout(x)
        
        x = self.fc(x)
        
        x = nn.functional.softmax(x, dim = 1)
           
        return {"gender": x}

In [13]:
masked_model = MaskedModel(False)

In [14]:
skin_race_model = SkintoneRaceModel()

In [15]:
gender_model = GenderModel()

In [16]:
import os
def checkpoint_save(name, state_dict, epoch):
    if not os.path.exists("checkpoint"):
        os.mkdir("checkpoint")
    save_path = "./checkpoint/" + name + "_epoch_{}".format(epoch)
    torch.save(state_dict, save_path)


In [17]:
training_generator = torch.utils.data.DataLoader(train_set, batch_size = 16, shuffle = False)
testing_generator = torch.utils.data.DataLoader(test_set, batch_size = 1, shuffle = False)

# Train function

In [18]:
def convert_result(y_pred):

    results = []
    for i in range(len(y_pred)):
        temp = y_pred[i] > 0.5
        temp = temp.astype(int)
        
        results.append(temp )

    return np.array(results )

In [19]:
import time

def test_model(model):
    model = model.to(device)
    criterion_crossEntropy = torch.nn.CrossEntropyLoss().to(device)
    
    start_time = time.time()

    
    test_loss = 0
    model.eval()
    y_predict = {'age': [], 'gender': [], 'masked': [], 'skintone': [], 'race': [], 'emotion': []}
    y_true = {'age': [], 'gender': [], 'masked': [], 'skintone': [], 'race': [], 'emotion': []}
    
    with torch.no_grad() :
        for b, (X, y) in tqdm(enumerate(testing_generator), total=len(testing_generator)):
            y_pred = model(X.to(device))
            
            loss = 0
            for key in y_pred.keys():
                loss += criterion_crossEntropy(y_pred[key], y[key].to(device))

                test_loss += loss.item()
                
                results = convert_result(y_pred[key].cpu().numpy())
       
                y_predict[key].extend(results)
                y_true[key].extend(y[key].numpy())
            
        print(f'\nDuration: {time.time() - start_time:.0f} seconds')  # print the time elapsed
    return test_loss, y_true, y_predict



In [20]:
def train_model(model, epochs = 10, steps = 256, checkpoint = None, test = True):
    start_time = time.time()
    
    model = model.to(device)
    criterion_crossEntropy = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    running_loss = 0
    train_acc, train_loss = [], [0] * epochs
    val_loss, val_acc = [0] * epochs, []

    for i in range(epochs):

        result_epoch = []
        #train
        model.train()
        for b, (X, y) in tqdm(enumerate(training_generator), total=steps):
            optimizer.zero_grad()

            b += 1
            if b > steps:
                break
            y_pred = model(X.to(device))
            #print(y_pred)
            #print(y)
            loss = 0
            for key in y_pred.keys():
                #print(key)
                #print(y_pred[key], y[key])
                loss += F.cross_entropy(y_pred[key], y[key].to(device), \
                                        weight = torch.Tensor(weights[key]).to(device) ) 
            
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_loss[i] += loss.item()

            if b % 50 == 0:

                print(f"Loss: {running_loss / 50} ", )

                running_loss = 0
        train_acc.append(np.sum(result_epoch) / np.array(result_epoch).size)
        if test == True:
            test_loss, y_true, y_predict = test_model(model)
            
            print(f"Test Loss: {test_loss} ", )
            
        if (i + 1) % 5 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.95
            if checkpoint is not None:
                checkpoint_save(checkpoint, model.state_dict(), i)

        #if save_checkpoint:
        #    checkpoint_save(name, model.state_dict(), i)

    print(f'\nDuration: {time.time() - start_time:.0f} seconds')  # print the time elapsed
    
    return train_loss


# Train

In [22]:
losses_gender = train_model(gender_model, epochs = 50, checkpoint = "gender_model", test = False)

 21%|██        | 54/256 [00:01<00:07, 28.09it/s]

Loss: 0.5758844950952136 


 41%|████      | 105/256 [00:03<00:05, 27.62it/s]

Loss: 0.5657463795849018 


 60%|█████▉    | 153/256 [00:05<00:03, 27.68it/s]

Loss: 0.537407907161753 


 80%|███████▉  | 204/256 [00:07<00:01, 28.07it/s]

Loss: 0.5428186549742776 


100%|█████████▉| 255/256 [00:09<00:00, 27.45it/s]

Loss: 0.5215474879351486 


100%|██████████| 256/256 [00:09<00:00, 27.82it/s]
  train_acc.append(np.sum(result_epoch) / np.array(result_epoch).size)
 21%|██        | 54/256 [00:01<00:07, 27.96it/s]

Loss: 0.5798959765686998 


 41%|████      | 105/256 [00:03<00:05, 27.93it/s]

Loss: 0.5205630456133877 


 60%|█████▉    | 153/256 [00:05<00:03, 28.03it/s]

Loss: 0.4999451432458652 


 80%|███████▉  | 204/256 [00:07<00:01, 28.01it/s]

Loss: 0.5123112591960075 


100%|█████████▉| 255/256 [00:09<00:00, 27.79it/s]

Loss: 0.4770193128504057 


100%|██████████| 256/256 [00:09<00:00, 27.95it/s]
 21%|██        | 54/256 [00:01<00:07, 28.03it/s]

Loss: 0.5470949377097295 


 41%|████      | 105/256 [00:03<00:05, 27.65it/s]

Loss: 0.48805732456450945 


 60%|█████▉    | 153/256 [00:05<00:03, 27.67it/s]

Loss: 0.48363422364431663 


 80%|███████▉  | 204/256 [00:07<00:01, 28.03it/s]

Loss: 0.4913844267095993 


100%|█████████▉| 255/256 [00:09<00:00, 27.84it/s]

Loss: 0.4574710768545842 


100%|██████████| 256/256 [00:09<00:00, 27.75it/s]
 21%|██        | 54/256 [00:01<00:07, 27.58it/s]

Loss: 0.530395306846049 


 41%|████      | 105/256 [00:03<00:05, 27.63it/s]

Loss: 0.4740428166706419 


 60%|█████▉    | 153/256 [00:05<00:03, 27.92it/s]

Loss: 0.4588261046594345 


 80%|███████▉  | 204/256 [00:07<00:01, 27.83it/s]

Loss: 0.47894772348482934 


100%|█████████▉| 255/256 [00:09<00:00, 27.99it/s]

Loss: 0.4380329147531512 


100%|██████████| 256/256 [00:09<00:00, 27.82it/s]
 21%|██        | 54/256 [00:01<00:07, 27.81it/s]

Loss: 0.50939580553755 


 41%|████      | 105/256 [00:03<00:05, 27.55it/s]

Loss: 0.45559862624734676 


 60%|█████▉    | 153/256 [00:05<00:03, 27.90it/s]

Loss: 0.45691374306376487 


 80%|███████▉  | 204/256 [00:07<00:01, 27.36it/s]

Loss: 0.47376222588785 


100%|█████████▉| 255/256 [00:09<00:00, 27.95it/s]

Loss: 0.4286936852018977 


100%|██████████| 256/256 [00:09<00:00, 27.84it/s]
 21%|██        | 54/256 [00:01<00:06, 28.91it/s]

Loss: 0.49408316106312694 


 41%|████      | 105/256 [00:03<00:05, 28.70it/s]

Loss: 0.43454532759461467 


 60%|█████▉    | 153/256 [00:05<00:03, 28.55it/s]

Loss: 0.4344087968700168 


 80%|███████▉  | 204/256 [00:07<00:01, 28.68it/s]

Loss: 0.46264071647131494 


100%|█████████▉| 255/256 [00:08<00:00, 27.91it/s]

Loss: 0.41528830728599453 


100%|██████████| 256/256 [00:08<00:00, 28.52it/s]
 21%|██        | 54/256 [00:01<00:06, 28.90it/s]

Loss: 0.48089137644188445 


 41%|████      | 105/256 [00:03<00:05, 28.93it/s]

Loss: 0.4200048675869056 


 60%|█████▉    | 153/256 [00:05<00:03, 28.38it/s]

Loss: 0.4356879954461256 


 80%|███████▉  | 204/256 [00:07<00:01, 28.82it/s]

Loss: 0.4520364806863288 


100%|█████████▉| 255/256 [00:08<00:00, 28.51it/s]

Loss: 0.41220044683110485 


100%|██████████| 256/256 [00:08<00:00, 28.58it/s]
 21%|██        | 54/256 [00:01<00:07, 28.73it/s]

Loss: 0.46920629808478764 


 41%|████      | 105/256 [00:03<00:05, 28.54it/s]

Loss: 0.4213347697472571 


 60%|█████▉    | 153/256 [00:05<00:03, 28.36it/s]

Loss: 0.4290803224370653 


 80%|███████▉  | 204/256 [00:07<00:01, 28.45it/s]

Loss: 0.43160842022535506 


100%|█████████▉| 255/256 [00:08<00:00, 28.56it/s]

Loss: 0.39777744665784653 


100%|██████████| 256/256 [00:08<00:00, 28.61it/s]
 21%|██        | 54/256 [00:01<00:07, 28.67it/s]

Loss: 0.46026329144833994 


 41%|████      | 105/256 [00:03<00:05, 28.44it/s]

Loss: 0.3974090532896302 


 60%|█████▉    | 153/256 [00:05<00:03, 28.70it/s]

Loss: 0.4155799075741454 


 80%|███████▉  | 204/256 [00:07<00:01, 28.37it/s]

Loss: 0.42023012582152985 


100%|█████████▉| 255/256 [00:08<00:00, 28.70it/s]

Loss: 0.39611864140670716 


100%|██████████| 256/256 [00:09<00:00, 28.44it/s]
 21%|██        | 54/256 [00:01<00:07, 28.77it/s]

Loss: 0.45238678956960254 


 41%|████      | 105/256 [00:03<00:05, 28.64it/s]

Loss: 0.3925796742598818 


 60%|█████▉    | 153/256 [00:05<00:03, 28.58it/s]

Loss: 0.4154326016417455 


 80%|███████▉  | 204/256 [00:07<00:01, 28.31it/s]

Loss: 0.41741487100028346 


100%|█████████▉| 255/256 [00:08<00:00, 28.30it/s]

Loss: 0.38542198553072715 


100%|██████████| 256/256 [00:09<00:00, 28.42it/s]
 21%|██        | 54/256 [00:01<00:07, 28.12it/s]

Loss: 0.4328358867517386 


 41%|████      | 105/256 [00:03<00:05, 28.07it/s]

Loss: 0.38667190607569196 


 60%|█████▉    | 153/256 [00:05<00:03, 27.81it/s]

Loss: 0.3950444993205531 


 80%|███████▉  | 204/256 [00:07<00:01, 28.15it/s]

Loss: 0.4065361015374866 


100%|█████████▉| 255/256 [00:09<00:00, 27.84it/s]

Loss: 0.38165765892260906 


100%|██████████| 256/256 [00:09<00:00, 27.94it/s]
 21%|██        | 54/256 [00:01<00:07, 27.89it/s]

Loss: 0.4330332415726471 


 41%|████      | 105/256 [00:03<00:05, 27.77it/s]

Loss: 0.3833692856058898 


 60%|█████▉    | 153/256 [00:05<00:03, 27.75it/s]

Loss: 0.3917375077255139 


 80%|███████▉  | 204/256 [00:07<00:01, 28.16it/s]

Loss: 0.39055842222669485 


100%|█████████▉| 255/256 [00:09<00:00, 27.52it/s]

Loss: 0.38547242455371056 


100%|██████████| 256/256 [00:09<00:00, 27.81it/s]
 21%|██        | 54/256 [00:01<00:07, 27.67it/s]

Loss: 0.4250794277272478 


 41%|████      | 105/256 [00:03<00:05, 27.82it/s]

Loss: 0.3771485501218599 


 60%|█████▉    | 153/256 [00:05<00:03, 27.94it/s]

Loss: 0.3845461145752468 


 80%|███████▉  | 204/256 [00:07<00:01, 28.17it/s]

Loss: 0.38696274735832753 


100%|█████████▉| 255/256 [00:09<00:00, 27.36it/s]

Loss: 0.3736136673740794 


100%|██████████| 256/256 [00:09<00:00, 27.85it/s]
 21%|██        | 54/256 [00:01<00:07, 28.05it/s]

Loss: 0.42442358915661255 


 41%|████      | 105/256 [00:03<00:05, 27.94it/s]

Loss: 0.37024183783185083 


 60%|█████▉    | 153/256 [00:05<00:03, 28.19it/s]

Loss: 0.3773911938236386 


 80%|███████▉  | 204/256 [00:07<00:01, 27.83it/s]

Loss: 0.3818275308596736 


100%|█████████▉| 255/256 [00:09<00:00, 27.57it/s]

Loss: 0.36233754493513615 


100%|██████████| 256/256 [00:09<00:00, 27.85it/s]
 21%|██        | 54/256 [00:01<00:07, 27.61it/s]

Loss: 0.42268008288953335 


 41%|████      | 105/256 [00:03<00:05, 27.65it/s]

Loss: 0.37235668669961136 


 60%|█████▉    | 153/256 [00:05<00:03, 27.69it/s]

Loss: 0.37040639530930364 


 80%|███████▉  | 204/256 [00:07<00:01, 27.57it/s]

Loss: 0.3893822252705557 


100%|█████████▉| 255/256 [00:09<00:00, 27.82it/s]

Loss: 0.3646309120494599 


100%|██████████| 256/256 [00:09<00:00, 27.61it/s]
 21%|██        | 54/256 [00:01<00:07, 27.85it/s]

Loss: 0.4204146408830999 


 41%|████      | 105/256 [00:03<00:05, 28.14it/s]

Loss: 0.37048120455514794 


 60%|█████▉    | 153/256 [00:05<00:03, 28.27it/s]

Loss: 0.3798582261462986 


 80%|███████▉  | 204/256 [00:07<00:01, 28.51it/s]

Loss: 0.381975397753913 


100%|█████████▉| 255/256 [00:09<00:00, 28.56it/s]

Loss: 0.3587973164099973 


100%|██████████| 256/256 [00:09<00:00, 28.25it/s]
 21%|██        | 54/256 [00:01<00:07, 28.54it/s]

Loss: 0.4144942530021647 


 41%|████      | 105/256 [00:03<00:05, 26.96it/s]

Loss: 0.36442893117967273 


 60%|█████▉    | 153/256 [00:05<00:03, 28.33it/s]

Loss: 0.36075137548604325 


 80%|███████▉  | 204/256 [00:07<00:01, 28.21it/s]

Loss: 0.37305752961068095 


100%|█████████▉| 255/256 [00:09<00:00, 28.40it/s]

Loss: 0.3577549526637548 


100%|██████████| 256/256 [00:09<00:00, 28.18it/s]
 21%|██        | 54/256 [00:01<00:07, 28.41it/s]

Loss: 0.4050633942523519 


 41%|████      | 105/256 [00:03<00:05, 28.44it/s]

Loss: 0.35418577033326293 


 60%|█████▉    | 153/256 [00:05<00:03, 28.40it/s]

Loss: 0.36091515865386464 


 80%|███████▉  | 204/256 [00:07<00:01, 28.69it/s]

Loss: 0.3695852732622106 


100%|█████████▉| 255/256 [00:08<00:00, 28.51it/s]

Loss: 0.36854863214543293 


100%|██████████| 256/256 [00:09<00:00, 28.37it/s]
 21%|██        | 54/256 [00:01<00:07, 28.64it/s]

Loss: 0.4107807057492919 


 41%|████      | 105/256 [00:03<00:05, 28.34it/s]

Loss: 0.3630922205856139 


 60%|█████▉    | 153/256 [00:05<00:03, 28.58it/s]

Loss: 0.3647224077756971 


 80%|███████▉  | 204/256 [00:07<00:01, 28.64it/s]

Loss: 0.3675310050505887 


100%|█████████▉| 255/256 [00:08<00:00, 28.60it/s]

Loss: 0.35614594452651227 


100%|██████████| 256/256 [00:08<00:00, 28.48it/s]
 21%|██        | 54/256 [00:01<00:07, 28.62it/s]

Loss: 0.40987248503303014 


 41%|████      | 105/256 [00:03<00:05, 28.78it/s]

Loss: 0.3571772573877623 


 60%|█████▉    | 153/256 [00:05<00:03, 28.90it/s]

Loss: 0.3573390050268836 


 80%|███████▉  | 204/256 [00:07<00:01, 28.57it/s]

Loss: 0.3598438045747704 


100%|█████████▉| 255/256 [00:08<00:00, 28.02it/s]

Loss: 0.351789104510852 


100%|██████████| 256/256 [00:09<00:00, 28.41it/s]
 21%|██        | 54/256 [00:01<00:07, 28.47it/s]

Loss: 0.4030003562661484 


 41%|████      | 105/256 [00:03<00:05, 28.67it/s]

Loss: 0.35178577815056505 


 60%|█████▉    | 153/256 [00:05<00:03, 28.73it/s]

Loss: 0.35542048207368654 


 80%|███████▉  | 204/256 [00:07<00:01, 28.32it/s]

Loss: 0.3690105341445706 


100%|█████████▉| 255/256 [00:08<00:00, 28.76it/s]

Loss: 0.3564894409924139 


100%|██████████| 256/256 [00:08<00:00, 28.49it/s]
 21%|██        | 54/256 [00:01<00:07, 27.82it/s]

Loss: 0.40199630617309445 


 41%|████      | 105/256 [00:03<00:05, 28.60it/s]

Loss: 0.35386684011116093 


 60%|█████▉    | 153/256 [00:05<00:03, 28.63it/s]

Loss: 0.3500550286804853 


 80%|███████▉  | 204/256 [00:07<00:01, 28.45it/s]

Loss: 0.35356490616444247 


100%|█████████▉| 255/256 [00:08<00:00, 28.66it/s]

Loss: 0.3464674093394801 


100%|██████████| 256/256 [00:09<00:00, 28.43it/s]
 21%|██        | 54/256 [00:01<00:07, 28.75it/s]

Loss: 0.3991874817598036 


 41%|████      | 105/256 [00:03<00:05, 28.61it/s]

Loss: 0.34939996443279464 


 60%|█████▉    | 153/256 [00:05<00:03, 28.73it/s]

Loss: 0.34520431464430956 


 80%|███████▉  | 204/256 [00:07<00:01, 28.40it/s]

Loss: 0.35351280368551435 


100%|█████████▉| 255/256 [00:08<00:00, 27.88it/s]

Loss: 0.3456548373286041 


100%|██████████| 256/256 [00:08<00:00, 28.56it/s]
 21%|██        | 54/256 [00:01<00:07, 28.73it/s]

Loss: 0.3928938145143574 


 41%|████      | 105/256 [00:03<00:05, 27.10it/s]

Loss: 0.3441947750468907 


 60%|█████▉    | 153/256 [00:05<00:03, 28.42it/s]

Loss: 0.35150220807696053 


 80%|███████▉  | 204/256 [00:07<00:01, 28.65it/s]

Loss: 0.3487980764086088 


100%|█████████▉| 255/256 [00:08<00:00, 28.73it/s]

Loss: 0.34265183455129106 


100%|██████████| 256/256 [00:09<00:00, 28.40it/s]
 21%|██        | 54/256 [00:01<00:06, 28.87it/s]

Loss: 0.39424060298929275 


 41%|████      | 105/256 [00:03<00:05, 28.73it/s]

Loss: 0.3544764151391795 


 60%|█████▉    | 153/256 [00:05<00:03, 28.76it/s]

Loss: 0.3451255919192826 


 80%|███████▉  | 204/256 [00:07<00:01, 28.72it/s]

Loss: 0.3551045905370696 


100%|█████████▉| 255/256 [00:08<00:00, 28.66it/s]

Loss: 0.3443565162815381 


100%|██████████| 256/256 [00:08<00:00, 28.73it/s]
 21%|██        | 54/256 [00:01<00:07, 28.25it/s]

Loss: 0.39306164632061125 


 41%|████      | 105/256 [00:03<00:05, 28.45it/s]

Loss: 0.3462609920498526 


 60%|█████▉    | 153/256 [00:05<00:03, 28.45it/s]

Loss: 0.34219861914473854 


 80%|███████▉  | 204/256 [00:07<00:01, 28.84it/s]

Loss: 0.340921445343602 


100%|█████████▉| 255/256 [00:08<00:00, 28.49it/s]

Loss: 0.33745698098364263 


100%|██████████| 256/256 [00:08<00:00, 28.46it/s]
 21%|██        | 54/256 [00:01<00:07, 28.77it/s]

Loss: 0.3955046562797319 


 41%|████      | 105/256 [00:03<00:05, 28.80it/s]

Loss: 0.34273270900266695 


 60%|█████▉    | 153/256 [00:05<00:03, 28.54it/s]

Loss: 0.3429944509166642 


 80%|███████▉  | 204/256 [00:07<00:01, 28.57it/s]

Loss: 0.3391184917802152 


100%|█████████▉| 255/256 [00:08<00:00, 27.48it/s]

Loss: 0.33033346878695996 


100%|██████████| 256/256 [00:09<00:00, 28.44it/s]
 21%|██        | 54/256 [00:01<00:07, 28.84it/s]

Loss: 0.37979412882286534 


 41%|████      | 105/256 [00:03<00:05, 27.67it/s]

Loss: 0.34223988317516224 


 60%|█████▉    | 153/256 [00:05<00:03, 28.70it/s]

Loss: 0.34472728895446003 


 80%|███████▉  | 204/256 [00:07<00:01, 28.77it/s]

Loss: 0.3509644932590411 


100%|█████████▉| 255/256 [00:08<00:00, 28.67it/s]

Loss: 0.33462833034570516 


100%|██████████| 256/256 [00:08<00:00, 28.55it/s]
 21%|██        | 54/256 [00:01<00:07, 28.79it/s]

Loss: 0.3797457229622749 


 41%|████      | 105/256 [00:03<00:05, 28.87it/s]

Loss: 0.3358716453384682 


 60%|█████▉    | 153/256 [00:05<00:03, 28.50it/s]

Loss: 0.34168514040762865 


 80%|███████▉  | 204/256 [00:07<00:01, 28.16it/s]

Loss: 0.34343181407594187 


100%|█████████▉| 255/256 [00:08<00:00, 28.45it/s]

Loss: 0.337262261977963 


100%|██████████| 256/256 [00:08<00:00, 28.62it/s]
 21%|██        | 54/256 [00:01<00:07, 28.60it/s]

Loss: 0.38066751504014035 


 41%|████      | 105/256 [00:03<00:05, 28.58it/s]

Loss: 0.3436833854956557 


 60%|█████▉    | 153/256 [00:05<00:03, 28.65it/s]

Loss: 0.3364137289307488 


 80%|███████▉  | 204/256 [00:07<00:01, 28.70it/s]

Loss: 0.3433983394102558 


100%|█████████▉| 255/256 [00:08<00:00, 28.69it/s]

Loss: 0.33240210880339277 


100%|██████████| 256/256 [00:08<00:00, 28.50it/s]
 21%|██        | 54/256 [00:01<00:07, 28.63it/s]

Loss: 0.3817511371011162 


 41%|████      | 105/256 [00:03<00:05, 28.25it/s]

Loss: 0.33865205677408633 


 60%|█████▉    | 153/256 [00:05<00:03, 28.44it/s]

Loss: 0.33852364662220696 


 80%|███████▉  | 204/256 [00:07<00:01, 28.33it/s]

Loss: 0.33786717311026526 


100%|█████████▉| 255/256 [00:08<00:00, 28.54it/s]

Loss: 0.3253213800590714 


100%|██████████| 256/256 [00:09<00:00, 28.37it/s]
 21%|██        | 54/256 [00:01<00:07, 28.72it/s]

Loss: 0.38590392796024986 


 41%|████      | 105/256 [00:03<00:05, 28.75it/s]

Loss: 0.33583777041299645 


 60%|█████▉    | 153/256 [00:05<00:03, 28.44it/s]

Loss: 0.33799640397717967 


 80%|███████▉  | 204/256 [00:07<00:01, 28.38it/s]

Loss: 0.3433800485142383 


100%|█████████▉| 255/256 [00:08<00:00, 28.62it/s]

Loss: 0.32918498202421154 


100%|██████████| 256/256 [00:08<00:00, 28.55it/s]
 21%|██        | 54/256 [00:01<00:07, 28.56it/s]

Loss: 0.3784970163405643 


 41%|████      | 105/256 [00:03<00:05, 28.53it/s]

Loss: 0.33567339563033466 


 60%|█████▉    | 153/256 [00:05<00:03, 28.82it/s]

Loss: 0.3383453470597047 


 80%|███████▉  | 204/256 [00:07<00:01, 28.84it/s]

Loss: 0.3360298550081558 


100%|█████████▉| 255/256 [00:08<00:00, 28.50it/s]

Loss: 0.33328979481252163 


100%|██████████| 256/256 [00:08<00:00, 28.68it/s]
 21%|██        | 54/256 [00:01<00:07, 28.77it/s]

Loss: 0.37979106904191196 


 41%|████      | 105/256 [00:03<00:05, 28.67it/s]

Loss: 0.3312032323702694 


 60%|█████▉    | 153/256 [00:05<00:03, 28.43it/s]

Loss: 0.33618656722755724 


 80%|███████▉  | 204/256 [00:07<00:01, 28.99it/s]

Loss: 0.3317269709508334 


100%|█████████▉| 255/256 [00:08<00:00, 27.46it/s]

Loss: 0.331558663144191 


100%|██████████| 256/256 [00:08<00:00, 28.59it/s]
 21%|██        | 54/256 [00:01<00:06, 28.86it/s]

Loss: 0.37770000822325456 


 41%|████      | 105/256 [00:03<00:05, 28.64it/s]

Loss: 0.33540196796804345 


 60%|█████▉    | 153/256 [00:05<00:03, 28.85it/s]

Loss: 0.33544976106737756 


 80%|███████▉  | 204/256 [00:07<00:01, 28.74it/s]

Loss: 0.32996822282824395 


100%|█████████▉| 255/256 [00:08<00:00, 28.76it/s]

Loss: 0.33169301734036644 


100%|██████████| 256/256 [00:08<00:00, 28.65it/s]
 21%|██        | 54/256 [00:01<00:07, 28.57it/s]

Loss: 0.38319272960133915 


 41%|████      | 105/256 [00:03<00:05, 28.66it/s]

Loss: 0.3351315515845353 


 60%|█████▉    | 153/256 [00:05<00:03, 28.52it/s]

Loss: 0.3398017044608072 


 80%|███████▉  | 204/256 [00:07<00:01, 28.66it/s]

Loss: 0.32947035331592495 


100%|█████████▉| 255/256 [00:08<00:00, 28.35it/s]

Loss: 0.3254620304235133 


100%|██████████| 256/256 [00:08<00:00, 28.51it/s]
 21%|██        | 54/256 [00:01<00:07, 28.68it/s]

Loss: 0.3775853383080743 


 41%|████      | 105/256 [00:03<00:05, 28.77it/s]

Loss: 0.33281569307784103 


 60%|█████▉    | 153/256 [00:05<00:03, 28.84it/s]

Loss: 0.33292356614011953 


 80%|███████▉  | 204/256 [00:07<00:01, 28.14it/s]

Loss: 0.3340671159632278 


100%|█████████▉| 255/256 [00:08<00:00, 28.91it/s]

Loss: 0.3294155922227071 


100%|██████████| 256/256 [00:08<00:00, 28.60it/s]
 21%|██        | 54/256 [00:01<00:06, 28.93it/s]

Loss: 0.378563776607631 


 41%|████      | 105/256 [00:03<00:05, 28.94it/s]

Loss: 0.3318627482188831 


 60%|█████▉    | 153/256 [00:05<00:03, 27.71it/s]

Loss: 0.33212620886377786 


 80%|███████▉  | 204/256 [00:07<00:01, 28.82it/s]

Loss: 0.3277991712304362 


100%|█████████▉| 255/256 [00:08<00:00, 28.68it/s]

Loss: 0.3251615259302229 


100%|██████████| 256/256 [00:08<00:00, 28.56it/s]
 21%|██        | 54/256 [00:01<00:07, 28.56it/s]

Loss: 0.3865079317858104 


 41%|████      | 105/256 [00:03<00:05, 28.67it/s]

Loss: 0.3376999160179757 


 60%|█████▉    | 153/256 [00:05<00:03, 28.46it/s]

Loss: 0.33532853406725605 


 80%|███████▉  | 204/256 [00:07<00:01, 28.44it/s]

Loss: 0.3309702118980325 


100%|█████████▉| 255/256 [00:08<00:00, 28.28it/s]

Loss: 0.3262597915569963 


100%|██████████| 256/256 [00:08<00:00, 28.57it/s]
 21%|██        | 54/256 [00:01<00:07, 28.82it/s]

Loss: 0.37582612508748675 


 41%|████      | 105/256 [00:03<00:05, 28.72it/s]

Loss: 0.336321844099413 


 60%|█████▉    | 153/256 [00:05<00:03, 28.83it/s]

Loss: 0.3357038818500572 


 80%|███████▉  | 204/256 [00:07<00:01, 28.55it/s]

Loss: 0.3293752431229499 


100%|█████████▉| 255/256 [00:08<00:00, 28.32it/s]

Loss: 0.3230516435772198 


100%|██████████| 256/256 [00:08<00:00, 28.56it/s]
 21%|██        | 54/256 [00:01<00:07, 28.72it/s]

Loss: 0.3767630089829418 


 41%|████      | 105/256 [00:03<00:05, 28.92it/s]

Loss: 0.3304304567423002 


 60%|█████▉    | 153/256 [00:05<00:03, 28.62it/s]

Loss: 0.3317765225201255 


 80%|███████▉  | 204/256 [00:07<00:01, 28.72it/s]

Loss: 0.33247416754547116 


100%|█████████▉| 255/256 [00:08<00:00, 28.61it/s]

Loss: 0.3208194487283595 


100%|██████████| 256/256 [00:08<00:00, 28.68it/s]
 21%|██        | 54/256 [00:01<00:06, 28.90it/s]

Loss: 0.37252115587338713 


 41%|████      | 105/256 [00:03<00:05, 28.87it/s]

Loss: 0.328904825978436 


 60%|█████▉    | 153/256 [00:05<00:03, 28.96it/s]

Loss: 0.3296540530198429 


 80%|███████▉  | 204/256 [00:07<00:01, 28.93it/s]

Loss: 0.32818044735558083 


100%|█████████▉| 255/256 [00:08<00:00, 28.74it/s]

Loss: 0.3225625140019081 


100%|██████████| 256/256 [00:08<00:00, 28.66it/s]
 21%|██        | 54/256 [00:01<00:06, 28.86it/s]

Loss: 0.3774457511990621 


 41%|████      | 105/256 [00:03<00:05, 28.48it/s]

Loss: 0.33731664024112085 


 60%|█████▉    | 153/256 [00:05<00:03, 28.40it/s]

Loss: 0.3396073603532954 


 80%|███████▉  | 204/256 [00:07<00:01, 28.58it/s]

Loss: 0.3327853929388246 


100%|█████████▉| 255/256 [00:08<00:00, 28.55it/s]

Loss: 0.3290408472167503 


100%|██████████| 256/256 [00:08<00:00, 28.58it/s]
 21%|██        | 54/256 [00:01<00:07, 28.58it/s]

Loss: 0.37718968432932237 


 41%|████      | 105/256 [00:03<00:05, 28.57it/s]

Loss: 0.3409682420957518 


 60%|█████▉    | 153/256 [00:05<00:03, 28.40it/s]

Loss: 0.33607431919158903 


 80%|███████▉  | 204/256 [00:07<00:01, 28.81it/s]

Loss: 0.3278509055677148 


100%|█████████▉| 255/256 [00:08<00:00, 28.44it/s]

Loss: 0.3254465033959473 


100%|██████████| 256/256 [00:08<00:00, 28.59it/s]
 21%|██        | 54/256 [00:01<00:07, 28.07it/s]

Loss: 0.3714309687920261 


 41%|████      | 105/256 [00:03<00:05, 28.39it/s]

Loss: 0.3299472704248938 


 60%|█████▉    | 153/256 [00:05<00:03, 26.85it/s]

Loss: 0.3314459157981214 


 80%|███████▉  | 204/256 [00:07<00:01, 28.78it/s]

Loss: 0.3251570118460795 


100%|█████████▉| 255/256 [00:09<00:00, 28.17it/s]

Loss: 0.3190784281508726 


100%|██████████| 256/256 [00:09<00:00, 27.81it/s]
 21%|██        | 54/256 [00:01<00:07, 28.60it/s]

Loss: 0.37687620225338747 


 41%|████      | 105/256 [00:03<00:05, 28.39it/s]

Loss: 0.3286876246159201 


 60%|█████▉    | 153/256 [00:05<00:03, 28.58it/s]

Loss: 0.33120310311944096 


 80%|███████▉  | 204/256 [00:07<00:01, 28.55it/s]

Loss: 0.3271591363883495 


100%|█████████▉| 255/256 [00:08<00:00, 28.54it/s]

Loss: 0.322598783632436 


100%|██████████| 256/256 [00:09<00:00, 28.43it/s]
 21%|██        | 54/256 [00:01<00:07, 27.89it/s]

Loss: 0.3842070798542385 


 41%|████      | 105/256 [00:03<00:05, 28.65it/s]

Loss: 0.33626288649652 


 60%|█████▉    | 153/256 [00:05<00:03, 28.42it/s]

Loss: 0.33012826686481206 


 80%|███████▉  | 204/256 [00:07<00:01, 28.31it/s]

Loss: 0.32340800042749734 


100%|█████████▉| 255/256 [00:08<00:00, 28.25it/s]

Loss: 0.31978539578622184 


100%|██████████| 256/256 [00:09<00:00, 28.34it/s]
 21%|██        | 54/256 [00:01<00:07, 27.35it/s]

Loss: 0.36848859789965993 


 41%|████      | 105/256 [00:03<00:05, 26.59it/s]

Loss: 0.33022627508991925 


 60%|█████▉    | 153/256 [00:05<00:03, 27.36it/s]

Loss: 0.32993372634883494 


 80%|███████▉  | 204/256 [00:07<00:01, 28.48it/s]

Loss: 0.3279923947244808 


100%|█████████▉| 255/256 [00:09<00:00, 26.72it/s]

Loss: 0.3184992704172506 


100%|██████████| 256/256 [00:09<00:00, 27.52it/s]
 21%|██        | 54/256 [00:01<00:07, 28.64it/s]

Loss: 0.36629227876058074 


 41%|████      | 105/256 [00:03<00:05, 28.50it/s]

Loss: 0.32806311779198466 


 60%|█████▉    | 153/256 [00:05<00:03, 28.42it/s]

Loss: 0.32862409941524534 


 80%|███████▉  | 204/256 [00:07<00:01, 28.53it/s]

Loss: 0.32310696507157927 


100%|█████████▉| 255/256 [00:08<00:00, 28.53it/s]

Loss: 0.31883039566693744 


100%|██████████| 256/256 [00:09<00:00, 28.34it/s]
 21%|██        | 54/256 [00:01<00:07, 28.61it/s]

Loss: 0.36702567479310394 


 41%|████      | 105/256 [00:03<00:05, 27.91it/s]

Loss: 0.32941158635049567 


 60%|█████▉    | 153/256 [00:05<00:03, 27.99it/s]

Loss: 0.34162346199494315 


 80%|███████▉  | 204/256 [00:07<00:01, 28.02it/s]

Loss: 0.3315245981008806 


100%|█████████▉| 255/256 [00:09<00:00, 28.25it/s]

Loss: 0.325472168167094 


100%|██████████| 256/256 [00:09<00:00, 28.14it/s]


Duration: 452 seconds





In [23]:
losses_skin_race = train_model(skin_race_model, epochs = 50, checkpoint = 'skintone_race_model', test = False)

 21%|██        | 53/256 [00:01<00:06, 32.44it/s]

Loss: 2.292836480869581 


 41%|████      | 105/256 [00:03<00:04, 32.61it/s]

Loss: 2.192550647109042 


 60%|█████▉    | 153/256 [00:04<00:03, 32.40it/s]

Loss: 2.1655663503724196 


 80%|████████  | 205/256 [00:06<00:01, 31.10it/s]

Loss: 2.062934408287636 


100%|██████████| 256/256 [00:08<00:00, 31.63it/s]
  train_acc.append(np.sum(result_epoch) / np.array(result_epoch).size)


Loss: 2.223969644778002 


 22%|██▏       | 56/256 [00:01<00:06, 32.56it/s]

Loss: 2.255747671846984 


 41%|████      | 104/256 [00:03<00:04, 32.67it/s]

Loss: 2.0588193711171945 


 61%|██████    | 156/256 [00:04<00:03, 32.44it/s]

Loss: 2.075730406138197 


 80%|███████▉  | 204/256 [00:06<00:01, 32.87it/s]

Loss: 1.9804036280986876 


100%|██████████| 256/256 [00:07<00:00, 32.53it/s]


Loss: 2.189768713427716 


 22%|██▏       | 56/256 [00:01<00:06, 32.18it/s]

Loss: 2.1989044064957435 


 41%|████      | 104/256 [00:03<00:04, 32.73it/s]

Loss: 2.019680592048394 


 61%|██████    | 156/256 [00:04<00:03, 32.81it/s]

Loss: 2.0233707196578177 


 80%|███████▉  | 204/256 [00:06<00:01, 32.26it/s]

Loss: 1.9356801438004674 


100%|██████████| 256/256 [00:07<00:00, 32.40it/s]


Loss: 2.0796917687029697 


 22%|██▏       | 56/256 [00:01<00:06, 32.44it/s]

Loss: 2.143853956681947 


 41%|████      | 104/256 [00:03<00:05, 29.99it/s]

Loss: 1.9499367265313265 


 61%|██████    | 156/256 [00:04<00:03, 32.38it/s]

Loss: 1.9799473719437772 


 80%|███████▉  | 204/256 [00:06<00:01, 32.92it/s]

Loss: 1.8970955924215607 


100%|██████████| 256/256 [00:07<00:00, 32.31it/s]


Loss: 2.041352975222396 


 22%|██▏       | 56/256 [00:01<00:06, 33.13it/s]

Loss: 2.1179452604466498 


 41%|████      | 104/256 [00:03<00:04, 32.93it/s]

Loss: 1.9187471794040685 


 61%|██████    | 156/256 [00:04<00:03, 32.95it/s]

Loss: 1.9566742903802319 


 80%|███████▉  | 204/256 [00:06<00:01, 32.32it/s]

Loss: 1.8529316962387101 


100%|██████████| 256/256 [00:07<00:00, 32.82it/s]

Loss: 2.0098264963414723 



 22%|██▏       | 56/256 [00:01<00:06, 32.77it/s]

Loss: 2.0728766768149582 


 41%|████      | 104/256 [00:03<00:04, 33.25it/s]

Loss: 1.9126486827464615 


 61%|██████    | 156/256 [00:04<00:03, 33.27it/s]

Loss: 1.9252639382089711 


 80%|███████▉  | 204/256 [00:06<00:01, 33.20it/s]

Loss: 1.821934276153128 


100%|██████████| 256/256 [00:07<00:00, 32.91it/s]


Loss: 1.9833949857828201 


 22%|██▏       | 56/256 [00:01<00:06, 33.21it/s]

Loss: 2.054628719967978 


 41%|████      | 104/256 [00:03<00:04, 33.11it/s]

Loss: 1.8612499393658295 


 61%|██████    | 156/256 [00:04<00:03, 32.99it/s]

Loss: 1.8702972517208036 


 80%|███████▉  | 204/256 [00:06<00:01, 32.89it/s]

Loss: 1.8014597158419536 


100%|██████████| 256/256 [00:07<00:00, 32.83it/s]


Loss: 1.9533897123280084 


 22%|██▏       | 56/256 [00:01<00:06, 32.47it/s]

Loss: 2.013071021631093 


 41%|████      | 104/256 [00:03<00:04, 32.89it/s]

Loss: 1.8408564271486145 


 61%|██████    | 156/256 [00:04<00:03, 33.09it/s]

Loss: 1.8641762450432728 


 80%|███████▉  | 204/256 [00:06<00:01, 32.81it/s]

Loss: 1.774833941283657 


100%|██████████| 256/256 [00:07<00:00, 32.68it/s]


Loss: 1.9027840094542612 


 22%|██▏       | 56/256 [00:01<00:06, 32.85it/s]

Loss: 1.9904955914427094 


 41%|████      | 104/256 [00:03<00:04, 32.48it/s]

Loss: 1.8140263498954463 


 61%|██████    | 156/256 [00:04<00:03, 31.87it/s]

Loss: 1.8494469596627527 


 80%|███████▉  | 204/256 [00:06<00:01, 31.82it/s]

Loss: 1.7637742618581895 


100%|██████████| 256/256 [00:07<00:00, 32.15it/s]


Loss: 1.904024583207612 


 22%|██▏       | 56/256 [00:01<00:06, 32.62it/s]

Loss: 1.9965862485549726 


 41%|████      | 104/256 [00:03<00:04, 32.93it/s]

Loss: 1.8011566853923435 


 61%|██████    | 156/256 [00:04<00:03, 32.99it/s]

Loss: 1.8062495346698981 


 80%|███████▉  | 204/256 [00:06<00:01, 32.09it/s]

Loss: 1.7323515616490024 


100%|██████████| 256/256 [00:07<00:00, 32.77it/s]

Loss: 1.856481892104262 



 21%|██▏       | 55/256 [00:01<00:06, 32.95it/s]

Loss: 1.958531099526071 


 40%|████      | 103/256 [00:03<00:04, 32.99it/s]

Loss: 1.7784058550310047 


 61%|██████    | 155/256 [00:04<00:03, 32.95it/s]

Loss: 1.7790649674964498 


 79%|███████▉  | 203/256 [00:06<00:01, 32.81it/s]

Loss: 1.6980073039228012 


100%|██████████| 256/256 [00:07<00:00, 32.74it/s]


Loss: 1.835581089999862 


 22%|██▏       | 56/256 [00:01<00:06, 32.86it/s]

Loss: 1.9414946592624274 


 41%|████      | 104/256 [00:03<00:04, 32.97it/s]

Loss: 1.77614932086593 


 61%|██████    | 156/256 [00:04<00:03, 32.74it/s]

Loss: 1.7462855288999375 


 80%|███████▉  | 204/256 [00:06<00:01, 32.40it/s]

Loss: 1.6823706904212927 


100%|██████████| 256/256 [00:07<00:00, 32.71it/s]


Loss: 1.8033421146582527 


 22%|██▏       | 56/256 [00:01<00:06, 32.55it/s]

Loss: 1.9154579451473421 


 41%|████      | 104/256 [00:03<00:04, 32.75it/s]

Loss: 1.7734309405183724 


 61%|██████    | 156/256 [00:04<00:03, 32.86it/s]

Loss: 1.7651077247637144 


 80%|███████▉  | 204/256 [00:06<00:01, 31.35it/s]

Loss: 1.6653387322843411 


100%|██████████| 256/256 [00:07<00:00, 32.41it/s]


Loss: 1.7835919756056482 


 22%|██▏       | 56/256 [00:01<00:06, 32.41it/s]

Loss: 1.91401844452159 


 41%|████      | 104/256 [00:03<00:04, 32.75it/s]

Loss: 1.7175876042618057 


 61%|██████    | 156/256 [00:04<00:03, 31.66it/s]

Loss: 1.7171471536118694 


 80%|███████▉  | 204/256 [00:06<00:01, 32.78it/s]

Loss: 1.6405547651038888 


100%|██████████| 256/256 [00:07<00:00, 32.49it/s]


Loss: 1.7848471621242856 


 22%|██▏       | 56/256 [00:01<00:06, 32.47it/s]

Loss: 1.8917835196625647 


 41%|████      | 104/256 [00:03<00:04, 31.40it/s]

Loss: 1.7009549800737733 


 61%|██████    | 156/256 [00:04<00:03, 32.88it/s]

Loss: 1.7184910927655563 


 80%|███████▉  | 204/256 [00:06<00:01, 33.00it/s]

Loss: 1.6289302510992525 


100%|██████████| 256/256 [00:07<00:00, 32.49it/s]

Loss: 1.75865923416494 



 22%|██▏       | 56/256 [00:01<00:06, 32.92it/s]

Loss: 1.86777786570725 


 41%|████      | 104/256 [00:03<00:04, 32.89it/s]

Loss: 1.7001308680391727 


 61%|██████    | 156/256 [00:04<00:03, 32.82it/s]

Loss: 1.708728460236845 


 80%|███████▉  | 204/256 [00:06<00:01, 32.95it/s]

Loss: 1.6250697043570372 


100%|██████████| 256/256 [00:07<00:00, 32.73it/s]


Loss: 1.7301623500131773 


 22%|██▏       | 56/256 [00:01<00:06, 32.61it/s]

Loss: 1.8606927660623152 


 41%|████      | 104/256 [00:03<00:04, 32.52it/s]

Loss: 1.678580931585941 


 61%|██████    | 156/256 [00:04<00:03, 32.96it/s]

Loss: 1.6702199480095559 


 80%|███████▉  | 204/256 [00:06<00:01, 30.65it/s]

Loss: 1.585918014167888 


100%|██████████| 256/256 [00:07<00:00, 32.34it/s]


Loss: 1.725794088042076 


 22%|██▏       | 56/256 [00:01<00:06, 32.94it/s]

Loss: 1.8310975166343986 


 41%|████      | 104/256 [00:03<00:04, 32.01it/s]

Loss: 1.6493389730671046 


 61%|██████    | 156/256 [00:04<00:03, 32.22it/s]

Loss: 1.663214757489799 


 80%|███████▉  | 204/256 [00:06<00:01, 32.39it/s]

Loss: 1.6017328023443351 


100%|██████████| 256/256 [00:07<00:00, 32.55it/s]


Loss: 1.7229538433831468 


 22%|██▏       | 56/256 [00:01<00:06, 32.42it/s]

Loss: 1.8295175612339674 


 41%|████      | 104/256 [00:03<00:04, 32.65it/s]

Loss: 1.6466414848499047 


 61%|██████    | 156/256 [00:04<00:03, 32.52it/s]

Loss: 1.6456706759081774 


 80%|███████▉  | 204/256 [00:06<00:01, 32.28it/s]

Loss: 1.5687558933850654 


100%|██████████| 256/256 [00:07<00:00, 32.33it/s]


Loss: 1.697607934132305 


 22%|██▏       | 56/256 [00:01<00:06, 32.11it/s]

Loss: 1.8055821963447416 


 41%|████      | 104/256 [00:03<00:04, 32.61it/s]

Loss: 1.627339950425377 


 61%|██████    | 156/256 [00:04<00:03, 32.41it/s]

Loss: 1.6417228288908086 


 80%|███████▉  | 204/256 [00:06<00:01, 32.59it/s]

Loss: 1.5740925184538215 


100%|██████████| 256/256 [00:07<00:00, 32.34it/s]

Loss: 1.679041272240989 



 22%|██▏       | 56/256 [00:01<00:06, 32.58it/s]

Loss: 1.7932420769372448 


 41%|████      | 104/256 [00:03<00:04, 32.84it/s]

Loss: 1.6142659353118767 


 61%|██████    | 156/256 [00:04<00:03, 32.64it/s]

Loss: 1.6302987328047618 


 80%|███████▉  | 204/256 [00:06<00:01, 31.62it/s]

Loss: 1.571145155856462 


100%|██████████| 256/256 [00:07<00:00, 32.29it/s]


Loss: 1.669218046559337 


 22%|██▏       | 56/256 [00:01<00:06, 32.07it/s]

Loss: 1.806835777040031 


 41%|████      | 104/256 [00:03<00:04, 32.50it/s]

Loss: 1.614238905579966 


 61%|██████    | 156/256 [00:04<00:03, 32.72it/s]

Loss: 1.607111123328996 


 80%|███████▉  | 204/256 [00:06<00:01, 32.62it/s]

Loss: 1.552457820460759 


100%|██████████| 256/256 [00:07<00:00, 32.51it/s]


Loss: 1.6763862139081305 


 22%|██▏       | 56/256 [00:01<00:06, 32.30it/s]

Loss: 1.79899347130931 


 41%|████      | 104/256 [00:03<00:04, 32.65it/s]

Loss: 1.6186844708419006 


 61%|██████    | 156/256 [00:04<00:03, 31.83it/s]

Loss: 1.6114936927312635 


 80%|███████▉  | 204/256 [00:06<00:01, 32.15it/s]

Loss: 1.5452314738766626 


100%|██████████| 256/256 [00:07<00:00, 32.14it/s]


Loss: 1.6598387044674936 


 22%|██▏       | 56/256 [00:01<00:06, 32.52it/s]

Loss: 1.7761571696202796 


 41%|████      | 104/256 [00:03<00:04, 32.57it/s]

Loss: 1.599595408807128 


 61%|██████    | 156/256 [00:04<00:03, 32.75it/s]

Loss: 1.588646353552894 


 80%|███████▉  | 204/256 [00:06<00:01, 32.16it/s]

Loss: 1.530637321609647 


100%|██████████| 256/256 [00:07<00:00, 32.42it/s]


Loss: 1.6488489619407458 


 22%|██▏       | 56/256 [00:01<00:06, 32.42it/s]

Loss: 1.7695383735699755 


 41%|████      | 104/256 [00:03<00:04, 32.71it/s]

Loss: 1.5906538285632161 


 61%|██████    | 156/256 [00:04<00:03, 32.66it/s]

Loss: 1.578891470385122 


 80%|███████▉  | 204/256 [00:06<00:01, 31.11it/s]

Loss: 1.518157044356421 


100%|██████████| 256/256 [00:07<00:00, 32.30it/s]

Loss: 1.6327470091602743 



 22%|██▏       | 56/256 [00:01<00:06, 32.54it/s]

Loss: 1.7670680761073163 


 41%|████      | 104/256 [00:03<00:04, 32.68it/s]

Loss: 1.5814166410134638 


 61%|██████    | 156/256 [00:04<00:03, 32.41it/s]

Loss: 1.5672328406461877 


 80%|███████▉  | 204/256 [00:06<00:01, 32.50it/s]

Loss: 1.4907183508021438 


100%|██████████| 256/256 [00:07<00:00, 32.44it/s]


Loss: 1.6324080252440438 


 22%|██▏       | 56/256 [00:01<00:06, 32.32it/s]

Loss: 1.738216445356924 


 41%|████      | 104/256 [00:03<00:04, 31.84it/s]

Loss: 1.563801266853446 


 61%|██████    | 156/256 [00:04<00:03, 32.45it/s]

Loss: 1.5593682427005535 


 80%|███████▉  | 204/256 [00:06<00:01, 32.04it/s]

Loss: 1.48937066434338 


100%|██████████| 256/256 [00:07<00:00, 32.20it/s]


Loss: 1.6156958699117498 


 22%|██▏       | 56/256 [00:01<00:06, 32.41it/s]

Loss: 1.7491794844548259 


 41%|████      | 104/256 [00:03<00:04, 32.41it/s]

Loss: 1.575926052820166 


 61%|██████    | 156/256 [00:04<00:03, 32.33it/s]

Loss: 1.6562350529525613 


 80%|███████▉  | 204/256 [00:06<00:01, 32.25it/s]

Loss: 1.5740023960438743 


100%|██████████| 256/256 [00:07<00:00, 32.27it/s]


Loss: 1.677052299202571 


 22%|██▏       | 56/256 [00:01<00:06, 32.61it/s]

Loss: 1.8174418978609905 


 41%|████      | 104/256 [00:03<00:04, 32.70it/s]

Loss: 1.6062527564117368 


 61%|██████    | 156/256 [00:04<00:03, 32.56it/s]

Loss: 1.6092159126565384 


 80%|███████▉  | 204/256 [00:06<00:01, 30.83it/s]

Loss: 1.5131414690999656 


100%|██████████| 256/256 [00:07<00:00, 32.35it/s]


Loss: 1.6342108497593957 


 22%|██▏       | 56/256 [00:01<00:06, 32.58it/s]

Loss: 1.736801493019449 


 41%|████      | 104/256 [00:03<00:04, 32.10it/s]

Loss: 1.562426860009101 


 61%|██████    | 156/256 [00:04<00:03, 32.46it/s]

Loss: 1.5382948232056068 


 80%|███████▉  | 204/256 [00:06<00:01, 32.20it/s]

Loss: 1.4864073403162579 


100%|██████████| 256/256 [00:07<00:00, 32.30it/s]

Loss: 1.5996869710698758 



 21%|██▏       | 55/256 [00:01<00:06, 32.29it/s]

Loss: 1.7304951288596482 


 40%|████      | 103/256 [00:03<00:04, 32.25it/s]

Loss: 1.5663815299903403 


 61%|██████    | 155/256 [00:04<00:03, 32.60it/s]

Loss: 1.5291073375772646 


 79%|███████▉  | 203/256 [00:06<00:01, 32.11it/s]

Loss: 1.4561150455944758 


100%|██████████| 256/256 [00:07<00:00, 32.21it/s]


Loss: 1.6118247209073955 


 22%|██▏       | 56/256 [00:01<00:06, 32.23it/s]

Loss: 1.7297022046574 


 41%|████      | 104/256 [00:03<00:04, 32.61it/s]

Loss: 1.554165257085988 


 61%|██████    | 156/256 [00:04<00:03, 32.76it/s]

Loss: 1.5327017552849977 


 80%|███████▉  | 204/256 [00:06<00:01, 32.16it/s]

Loss: 1.4756225072273121 


100%|██████████| 256/256 [00:07<00:00, 32.23it/s]


Loss: 1.596641713126326 


 22%|██▏       | 56/256 [00:01<00:06, 32.07it/s]

Loss: 1.7228597480949799 


 41%|████      | 104/256 [00:03<00:04, 32.50it/s]

Loss: 1.5545625406399106 


 61%|██████    | 156/256 [00:04<00:03, 32.31it/s]

Loss: 1.5210920810621535 


 80%|███████▉  | 204/256 [00:06<00:01, 31.52it/s]

Loss: 1.4587770290073616 


100%|██████████| 256/256 [00:07<00:00, 32.17it/s]


Loss: 1.5857867873192821 


 22%|██▏       | 56/256 [00:01<00:06, 32.15it/s]

Loss: 1.7301591353387298 


 41%|████      | 104/256 [00:03<00:04, 32.63it/s]

Loss: 1.540970173237019 


 61%|██████    | 156/256 [00:04<00:03, 32.48it/s]

Loss: 1.5311908351264407 


 80%|███████▉  | 204/256 [00:06<00:01, 31.92it/s]

Loss: 1.4455737250427314 


100%|██████████| 256/256 [00:07<00:00, 32.42it/s]


Loss: 1.5869679315780871 


 22%|██▏       | 56/256 [00:01<00:06, 32.60it/s]

Loss: 1.7166628287380354 


 41%|████      | 104/256 [00:03<00:04, 32.38it/s]

Loss: 1.5384332055509475 


 61%|██████    | 156/256 [00:04<00:03, 32.71it/s]

Loss: 1.5196602695647372 


 80%|███████▉  | 204/256 [00:06<00:01, 32.50it/s]

Loss: 1.4520912413590055 


100%|██████████| 256/256 [00:07<00:00, 32.44it/s]

Loss: 1.5690962697507604 



 22%|██▏       | 56/256 [00:01<00:06, 32.21it/s]

Loss: 1.69837389168317 


 41%|████      | 104/256 [00:03<00:04, 32.29it/s]

Loss: 1.5370188047583384 


 61%|██████    | 156/256 [00:04<00:03, 32.24it/s]

Loss: 1.5130870917091295 


 80%|███████▉  | 204/256 [00:06<00:01, 32.26it/s]

Loss: 1.4533229049191094 


100%|██████████| 256/256 [00:07<00:00, 32.17it/s]


Loss: 1.5688367173790783 


 22%|██▏       | 56/256 [00:01<00:06, 32.80it/s]

Loss: 1.6947990386638265 


 41%|████      | 104/256 [00:03<00:04, 32.73it/s]

Loss: 1.525938896312832 


 61%|██████    | 156/256 [00:04<00:03, 32.61it/s]

Loss: 1.503318459124992 


 80%|███████▉  | 204/256 [00:06<00:01, 30.42it/s]

Loss: 1.4367139454545674 


100%|██████████| 256/256 [00:07<00:00, 32.21it/s]


Loss: 1.5777086114110557 


 22%|██▏       | 56/256 [00:01<00:06, 32.42it/s]

Loss: 1.697304270635351 


 41%|████      | 104/256 [00:03<00:04, 32.51it/s]

Loss: 1.5226457263932895 


 61%|██████    | 156/256 [00:04<00:03, 31.68it/s]

Loss: 1.4963370561058138 


 80%|███████▉  | 204/256 [00:06<00:01, 32.22it/s]

Loss: 1.4301500245897063 


100%|██████████| 256/256 [00:07<00:00, 32.04it/s]


Loss: 1.566691987351277 


 22%|██▏       | 56/256 [00:01<00:06, 31.03it/s]

Loss: 1.6953401708611004 


 41%|████      | 104/256 [00:03<00:04, 31.02it/s]

Loss: 1.5249374910730082 


 61%|██████    | 156/256 [00:05<00:03, 30.19it/s]

Loss: 1.5275994422579982 


 80%|███████▉  | 204/256 [00:06<00:01, 30.06it/s]

Loss: 1.441461441028353 


100%|██████████| 256/256 [00:08<00:00, 31.00it/s]

Loss: 1.5867103145467802 


100%|██████████| 256/256 [00:08<00:00, 30.90it/s]
 22%|██▏       | 56/256 [00:01<00:06, 31.36it/s]

Loss: 1.6983039328630962 


 41%|████      | 104/256 [00:03<00:04, 31.53it/s]

Loss: 1.5127117069472162 


 61%|██████    | 156/256 [00:04<00:03, 32.62it/s]

Loss: 1.5195833761693829 


 80%|███████▉  | 204/256 [00:06<00:01, 32.48it/s]

Loss: 1.450115242647023 


100%|██████████| 256/256 [00:08<00:00, 31.97it/s]

Loss: 1.5894459643913288 



 21%|██▏       | 55/256 [00:01<00:06, 32.64it/s]

Loss: 1.6922280440366058 


 40%|████      | 103/256 [00:03<00:04, 32.31it/s]

Loss: 1.5414570201911684 


 61%|██████    | 155/256 [00:04<00:03, 32.59it/s]

Loss: 1.5092396355311548 


 79%|███████▉  | 203/256 [00:06<00:01, 32.48it/s]

Loss: 1.429751493468482 


100%|██████████| 256/256 [00:07<00:00, 32.36it/s]


Loss: 1.5646810951619299 


 22%|██▏       | 56/256 [00:01<00:06, 32.77it/s]

Loss: 1.6841991504496432 


 41%|████      | 104/256 [00:03<00:04, 32.64it/s]

Loss: 1.5040991830895225 


 61%|██████    | 156/256 [00:04<00:03, 32.50it/s]

Loss: 1.4899446967149452 


 80%|███████▉  | 204/256 [00:06<00:01, 32.73it/s]

Loss: 1.4137390701948478 


100%|██████████| 256/256 [00:07<00:00, 32.56it/s]


Loss: 1.5601671938846096 


 22%|██▏       | 56/256 [00:01<00:06, 32.75it/s]

Loss: 1.683092867338001 


 41%|████      | 104/256 [00:03<00:04, 32.75it/s]

Loss: 1.479830353785513 


 61%|██████    | 156/256 [00:04<00:03, 32.27it/s]

Loss: 1.4839901282636345 


 80%|███████▉  | 204/256 [00:06<00:01, 32.48it/s]

Loss: 1.4072795256347155 


100%|██████████| 256/256 [00:07<00:00, 32.27it/s]


Loss: 1.560880788884891 


 22%|██▏       | 56/256 [00:01<00:06, 32.49it/s]

Loss: 1.6719792133658948 


 41%|████      | 104/256 [00:03<00:04, 32.80it/s]

Loss: 1.517431794895843 


 61%|██████    | 156/256 [00:04<00:03, 32.58it/s]

Loss: 1.489690792099432 


 80%|███████▉  | 204/256 [00:06<00:01, 32.24it/s]

Loss: 1.425072873695946 


100%|██████████| 256/256 [00:07<00:00, 32.35it/s]


Loss: 1.5859685095074907 


 22%|██▏       | 56/256 [00:01<00:06, 32.69it/s]

Loss: 1.6621668857554597 


 41%|████      | 104/256 [00:03<00:04, 31.94it/s]

Loss: 1.475428016357244 


 61%|██████    | 156/256 [00:04<00:03, 31.97it/s]

Loss: 1.4753495042565634 


 80%|███████▉  | 204/256 [00:06<00:01, 32.48it/s]

Loss: 1.4110191028062307 


100%|██████████| 256/256 [00:07<00:00, 32.26it/s]

Loss: 1.5837175492963413 



 22%|██▏       | 56/256 [00:01<00:06, 32.22it/s]

Loss: 1.6858954159209867 


 41%|████      | 104/256 [00:03<00:04, 32.64it/s]

Loss: 1.4789022489234753 


 61%|██████    | 156/256 [00:04<00:03, 32.47it/s]

Loss: 1.4682613266741915 


 80%|███████▉  | 204/256 [00:06<00:01, 32.70it/s]

Loss: 1.3947666487722046 


100%|██████████| 256/256 [00:07<00:00, 32.41it/s]


Loss: 1.540826148897276 


 22%|██▏       | 56/256 [00:01<00:06, 32.05it/s]

Loss: 1.666987872498541 


 41%|████      | 104/256 [00:03<00:04, 31.92it/s]

Loss: 1.4657399141231684 


 61%|██████    | 156/256 [00:04<00:03, 32.75it/s]

Loss: 1.463443464369223 


 80%|███████▉  | 204/256 [00:06<00:01, 32.61it/s]

Loss: 1.3908180393607485 


100%|██████████| 256/256 [00:07<00:00, 32.36it/s]


Loss: 1.5438081880283112 


 22%|██▏       | 56/256 [00:01<00:06, 32.77it/s]

Loss: 1.6578792021781243 


 41%|████      | 104/256 [00:03<00:04, 32.68it/s]

Loss: 1.4590969104780809 


 61%|██████    | 156/256 [00:04<00:03, 31.95it/s]

Loss: 1.4594806013650576 


 80%|███████▉  | 204/256 [00:06<00:01, 32.65it/s]

Loss: 1.3816160110500701 


100%|██████████| 256/256 [00:07<00:00, 32.56it/s]


Loss: 1.5258470608882218 


 22%|██▏       | 56/256 [00:01<00:06, 32.53it/s]

Loss: 1.6592029653112477 


 41%|████      | 104/256 [00:03<00:04, 32.89it/s]

Loss: 1.4595049676299479 


 61%|██████    | 156/256 [00:04<00:03, 32.81it/s]

Loss: 1.4540031890706049 


 80%|███████▉  | 204/256 [00:06<00:01, 31.89it/s]

Loss: 1.383690960395111 


100%|██████████| 256/256 [00:07<00:00, 32.20it/s]


Loss: 1.5374702819100372 


 22%|██▏       | 56/256 [00:01<00:06, 32.39it/s]

Loss: 1.656707210397628 


 41%|████      | 104/256 [00:03<00:04, 31.91it/s]

Loss: 1.4658749018216595 


 61%|██████    | 156/256 [00:04<00:03, 32.20it/s]

Loss: 1.4594259805145027 


 80%|███████▉  | 204/256 [00:06<00:01, 32.77it/s]

Loss: 1.37765985069257 


100%|██████████| 256/256 [00:07<00:00, 32.38it/s]

Loss: 1.533413795145947 

Duration: 396 seconds





In [24]:
losses_masked = train_model(masked_model, epochs = 50, checkpoint = "masked_model", test = False)

 21%|██▏       | 55/256 [00:02<00:06, 28.77it/s]

Loss: 0.7125947646184057 


 40%|████      | 103/256 [00:03<00:05, 28.82it/s]

Loss: 0.4960687019302555 


 60%|██████    | 154/256 [00:05<00:03, 28.48it/s]

Loss: 0.5204139500801389 


 80%|████████  | 205/256 [00:07<00:01, 28.91it/s]

Loss: 0.5027937865457571 


 99%|█████████▉| 253/256 [00:08<00:00, 28.62it/s]

Loss: 0.4838830035254653 


100%|██████████| 256/256 [00:09<00:00, 28.39it/s]
  train_acc.append(np.sum(result_epoch) / np.array(result_epoch).size)
 21%|██        | 54/256 [00:01<00:07, 28.75it/s]

Loss: 0.6281159695865499 


 41%|████      | 105/256 [00:03<00:05, 28.95it/s]

Loss: 0.4048342608668419 


 60%|█████▉    | 153/256 [00:05<00:03, 28.93it/s]

Loss: 0.4484141398591294 


 80%|███████▉  | 204/256 [00:07<00:01, 28.65it/s]

Loss: 0.454119146077239 


100%|█████████▉| 255/256 [00:08<00:00, 28.36it/s]

Loss: 0.44157967976379775 


100%|██████████| 256/256 [00:08<00:00, 28.68it/s]
 21%|██        | 54/256 [00:01<00:06, 28.89it/s]

Loss: 0.5927855399739158 


 41%|████      | 105/256 [00:03<00:05, 26.48it/s]

Loss: 0.37974291408763067 


 60%|█████▉    | 153/256 [00:05<00:03, 28.11it/s]

Loss: 0.4262200608349337 


 80%|███████▉  | 204/256 [00:07<00:01, 27.72it/s]

Loss: 0.4080368966838743 


100%|█████████▉| 255/256 [00:09<00:00, 28.62it/s]

Loss: 0.42138964551181873 


100%|██████████| 256/256 [00:09<00:00, 27.98it/s]
 21%|██        | 54/256 [00:01<00:07, 28.38it/s]

Loss: 0.562102955200452 


 41%|████      | 105/256 [00:03<00:05, 28.13it/s]

Loss: 0.37196388806926173 


 60%|█████▉    | 153/256 [00:05<00:03, 28.17it/s]

Loss: 0.3913004965604825 


 80%|███████▉  | 204/256 [00:07<00:01, 28.35it/s]

Loss: 0.3911754629825967 


100%|█████████▉| 255/256 [00:08<00:00, 28.47it/s]

Loss: 0.3993803307591995 


100%|██████████| 256/256 [00:09<00:00, 28.34it/s]
 21%|██        | 54/256 [00:01<00:07, 28.61it/s]

Loss: 0.535948820834334 


 41%|████      | 105/256 [00:03<00:05, 28.37it/s]

Loss: 0.3418393860052183 


 60%|█████▉    | 153/256 [00:05<00:03, 28.77it/s]

Loss: 0.3732360859962112 


 80%|███████▉  | 204/256 [00:07<00:01, 28.70it/s]

Loss: 0.37353705492843603 


100%|█████████▉| 255/256 [00:08<00:00, 28.67it/s]

Loss: 0.3788747131571675 


100%|██████████| 256/256 [00:08<00:00, 28.56it/s]
 21%|██        | 54/256 [00:01<00:07, 28.49it/s]

Loss: 0.5293244587800161 


 41%|████      | 105/256 [00:03<00:05, 28.66it/s]

Loss: 0.32774216078722435 


 60%|█████▉    | 153/256 [00:05<00:03, 28.60it/s]

Loss: 0.3448740920164946 


 80%|███████▉  | 204/256 [00:07<00:01, 28.66it/s]

Loss: 0.34280594047648805 


100%|█████████▉| 255/256 [00:08<00:00, 28.46it/s]

Loss: 0.37099602327526765 


100%|██████████| 256/256 [00:08<00:00, 28.45it/s]
 21%|██        | 54/256 [00:01<00:07, 28.07it/s]

Loss: 0.5149963112020949 


 41%|████      | 105/256 [00:03<00:05, 28.17it/s]

Loss: 0.3238267769379333 


 60%|█████▉    | 153/256 [00:05<00:03, 28.40it/s]

Loss: 0.33984560432444527 


 80%|███████▉  | 204/256 [00:07<00:01, 27.70it/s]

Loss: 0.3291014418778154 


100%|█████████▉| 255/256 [00:09<00:00, 28.37it/s]

Loss: 0.37694163349507015 


100%|██████████| 256/256 [00:09<00:00, 28.28it/s]
 21%|██        | 54/256 [00:01<00:07, 28.30it/s]

Loss: 0.5068339500522462 


 41%|████      | 105/256 [00:03<00:05, 28.51it/s]

Loss: 0.31514828893264873 


 60%|█████▉    | 153/256 [00:05<00:03, 28.54it/s]

Loss: 0.3405587793673492 


 80%|███████▉  | 204/256 [00:07<00:01, 28.41it/s]

Loss: 0.3385222844132915 


100%|█████████▉| 255/256 [00:08<00:00, 28.56it/s]

Loss: 0.3644714167064385 


100%|██████████| 256/256 [00:09<00:00, 28.33it/s]
 21%|██        | 54/256 [00:01<00:07, 28.70it/s]

Loss: 0.48415423548862613 


 41%|████      | 105/256 [00:03<00:05, 28.64it/s]

Loss: 0.3051218234263709 


 60%|█████▉    | 153/256 [00:05<00:03, 28.78it/s]

Loss: 0.34545455570475947 


 80%|███████▉  | 204/256 [00:07<00:01, 28.62it/s]

Loss: 0.33660816204637534 


100%|█████████▉| 255/256 [00:08<00:00, 27.87it/s]

Loss: 0.36293870435048325 


100%|██████████| 256/256 [00:08<00:00, 28.57it/s]
 21%|██        | 54/256 [00:01<00:07, 28.49it/s]

Loss: 0.4901834599409821 


 41%|████      | 105/256 [00:03<00:05, 27.47it/s]

Loss: 0.3009212657929534 


 60%|█████▉    | 153/256 [00:05<00:03, 28.11it/s]

Loss: 0.3174434673804976 


 80%|███████▉  | 204/256 [00:07<00:01, 28.47it/s]

Loss: 0.3225168352732172 


100%|█████████▉| 255/256 [00:09<00:00, 28.49it/s]

Loss: 0.36717361452590475 


100%|██████████| 256/256 [00:09<00:00, 28.25it/s]
 21%|██        | 54/256 [00:01<00:07, 28.37it/s]

Loss: 0.4699395628681191 


 41%|████      | 105/256 [00:03<00:05, 28.35it/s]

Loss: 0.30238378216557765 


 60%|█████▉    | 153/256 [00:05<00:03, 28.12it/s]

Loss: 0.31446834508406235 


 80%|███████▉  | 204/256 [00:07<00:01, 28.09it/s]

Loss: 0.317687111452357 


100%|█████████▉| 255/256 [00:08<00:00, 28.55it/s]

Loss: 0.3755824889771947 


100%|██████████| 256/256 [00:09<00:00, 28.39it/s]
 21%|██        | 54/256 [00:01<00:07, 28.32it/s]

Loss: 0.46409684099397486 


 41%|████      | 105/256 [00:03<00:05, 28.03it/s]

Loss: 0.29581765505359636 


 60%|█████▉    | 153/256 [00:05<00:03, 28.64it/s]

Loss: 0.3262896457350728 


 80%|███████▉  | 204/256 [00:07<00:01, 28.69it/s]

Loss: 0.3198991197861257 


100%|█████████▉| 255/256 [00:08<00:00, 28.75it/s]

Loss: 0.3543846634009362 


100%|██████████| 256/256 [00:08<00:00, 28.46it/s]
 21%|██        | 54/256 [00:01<00:07, 28.54it/s]

Loss: 0.4699158117775277 


 41%|████      | 105/256 [00:03<00:05, 28.50it/s]

Loss: 0.29678520022342986 


 60%|█████▉    | 153/256 [00:05<00:03, 28.52it/s]

Loss: 0.310494152714488 


 80%|███████▉  | 204/256 [00:07<00:01, 28.35it/s]

Loss: 0.3190247640033094 


100%|█████████▉| 255/256 [00:09<00:00, 28.19it/s]

Loss: 0.3500124592320907 


100%|██████████| 256/256 [00:09<00:00, 28.21it/s]
 21%|██        | 54/256 [00:01<00:07, 27.84it/s]

Loss: 0.45443970027618547 


 41%|████      | 105/256 [00:03<00:05, 28.07it/s]

Loss: 0.2887009004188444 


 60%|█████▉    | 153/256 [00:05<00:03, 28.62it/s]

Loss: 0.3079707898951695 


 80%|███████▉  | 204/256 [00:07<00:01, 28.31it/s]

Loss: 0.31530163175806725 


100%|█████████▉| 255/256 [00:09<00:00, 28.40it/s]

Loss: 0.3503073567587968 


100%|██████████| 256/256 [00:09<00:00, 28.25it/s]
 21%|██        | 54/256 [00:01<00:07, 28.58it/s]

Loss: 0.4464305894753338 


 41%|████      | 105/256 [00:03<00:05, 28.68it/s]

Loss: 0.283051305654384 


 60%|█████▉    | 153/256 [00:05<00:03, 28.44it/s]

Loss: 0.3039765695256879 


 80%|███████▉  | 204/256 [00:07<00:01, 28.44it/s]

Loss: 0.3133342710038363 


100%|█████████▉| 255/256 [00:08<00:00, 28.48it/s]

Loss: 0.3476883987764341 


100%|██████████| 256/256 [00:09<00:00, 28.41it/s]
 21%|██        | 54/256 [00:01<00:07, 28.66it/s]

Loss: 0.4430719962684798 


 41%|████      | 105/256 [00:03<00:05, 28.10it/s]

Loss: 0.2836628531618772 


 60%|█████▉    | 153/256 [00:05<00:03, 28.16it/s]

Loss: 0.29166878426387444 


 80%|███████▉  | 204/256 [00:07<00:01, 28.50it/s]

Loss: 0.31037047981745525 


100%|█████████▉| 255/256 [00:08<00:00, 28.44it/s]

Loss: 0.3478168729731895 


100%|██████████| 256/256 [00:09<00:00, 28.33it/s]
 21%|██        | 54/256 [00:01<00:07, 28.66it/s]

Loss: 0.4404592932095742 


 41%|████      | 105/256 [00:03<00:05, 27.50it/s]

Loss: 0.28436827980084245 


 60%|█████▉    | 153/256 [00:05<00:03, 28.39it/s]

Loss: 0.2951258103306086 


 80%|███████▉  | 204/256 [00:07<00:01, 28.24it/s]

Loss: 0.31265033941419373 


100%|█████████▉| 255/256 [00:09<00:00, 28.60it/s]

Loss: 0.34758463316639365 


100%|██████████| 256/256 [00:09<00:00, 28.18it/s]
 21%|██        | 54/256 [00:01<00:07, 28.39it/s]

Loss: 0.4437729370506597 


 41%|████      | 105/256 [00:03<00:05, 28.54it/s]

Loss: 0.3268462279187206 


 60%|█████▉    | 153/256 [00:05<00:03, 28.48it/s]

Loss: 0.3518849996297401 


 80%|███████▉  | 204/256 [00:07<00:01, 28.24it/s]

Loss: 0.3197154173080886 


100%|█████████▉| 255/256 [00:08<00:00, 28.70it/s]

Loss: 0.35470915603049036 


100%|██████████| 256/256 [00:08<00:00, 28.45it/s]
 21%|██        | 54/256 [00:01<00:07, 28.48it/s]

Loss: 0.44674419206192895 


 41%|████      | 105/256 [00:03<00:05, 28.33it/s]

Loss: 0.28837622471871827 


 60%|█████▉    | 153/256 [00:05<00:03, 28.58it/s]

Loss: 0.3017641703236311 


 80%|███████▉  | 204/256 [00:07<00:01, 28.71it/s]

Loss: 0.3143592926683219 


100%|█████████▉| 255/256 [00:08<00:00, 28.48it/s]

Loss: 0.3498186617779938 


100%|██████████| 256/256 [00:09<00:00, 28.43it/s]
 21%|██        | 54/256 [00:01<00:07, 28.62it/s]

Loss: 0.4392550664020733 


 41%|████      | 105/256 [00:03<00:05, 28.36it/s]

Loss: 0.28217444801144803 


 60%|█████▉    | 153/256 [00:05<00:03, 28.26it/s]

Loss: 0.28888639812556116 


 80%|███████▉  | 204/256 [00:07<00:01, 28.21it/s]

Loss: 0.3094731181793742 


100%|█████████▉| 255/256 [00:09<00:00, 28.45it/s]

Loss: 0.34608539103839786 


100%|██████████| 256/256 [00:09<00:00, 28.17it/s]
 21%|██        | 54/256 [00:01<00:07, 28.31it/s]

Loss: 0.4376034042614102 


 41%|████      | 105/256 [00:03<00:05, 28.52it/s]

Loss: 0.28073139405147773 


 60%|█████▉    | 153/256 [00:05<00:03, 28.40it/s]

Loss: 0.2872308473877686 


 80%|███████▉  | 204/256 [00:07<00:01, 28.70it/s]

Loss: 0.30848592546887 


100%|█████████▉| 255/256 [00:08<00:00, 28.83it/s]

Loss: 0.34587074135203916 


100%|██████████| 256/256 [00:08<00:00, 28.48it/s]
 21%|██        | 54/256 [00:01<00:07, 28.49it/s]

Loss: 0.4371685541885954 


 41%|████      | 105/256 [00:03<00:05, 28.33it/s]

Loss: 0.28039984266072016 


 60%|█████▉    | 153/256 [00:05<00:03, 28.52it/s]

Loss: 0.28576504463413566 


 80%|███████▉  | 204/256 [00:07<00:01, 28.28it/s]

Loss: 0.3081470809492171 


100%|█████████▉| 255/256 [00:08<00:00, 28.76it/s]

Loss: 0.345508714844543 


100%|██████████| 256/256 [00:08<00:00, 28.45it/s]
 21%|██        | 54/256 [00:01<00:07, 27.99it/s]

Loss: 0.43593590221802114 


 41%|████      | 105/256 [00:03<00:05, 28.10it/s]

Loss: 0.27961159199589014 


 60%|█████▉    | 153/256 [00:05<00:03, 28.17it/s]

Loss: 0.2847306138181463 


 80%|███████▉  | 204/256 [00:07<00:01, 27.75it/s]

Loss: 0.30776041655429154 


100%|█████████▉| 255/256 [00:09<00:00, 27.89it/s]

Loss: 0.34544370732360485 


100%|██████████| 256/256 [00:09<00:00, 27.97it/s]
 21%|██        | 54/256 [00:01<00:07, 27.61it/s]

Loss: 0.43544510134284314 


 41%|████      | 105/256 [00:03<00:05, 27.21it/s]

Loss: 0.2793383072087305 


 60%|█████▉    | 153/256 [00:05<00:03, 28.21it/s]

Loss: 0.28448611854780187 


 80%|███████▉  | 204/256 [00:07<00:01, 28.14it/s]

Loss: 0.30755902152550946 


100%|█████████▉| 255/256 [00:09<00:00, 27.97it/s]

Loss: 0.3447481837094786 


100%|██████████| 256/256 [00:09<00:00, 27.81it/s]
 21%|██        | 54/256 [00:01<00:07, 27.83it/s]

Loss: 0.4348112066826971 


 41%|████      | 105/256 [00:03<00:05, 27.54it/s]

Loss: 0.2791355821963874 


 60%|█████▉    | 153/256 [00:05<00:03, 28.42it/s]

Loss: 0.28407854823623796 


 80%|███████▉  | 204/256 [00:07<00:01, 28.06it/s]

Loss: 0.3072882001308841 


100%|█████████▉| 255/256 [00:09<00:00, 28.30it/s]

Loss: 0.34464052004309126 


100%|██████████| 256/256 [00:09<00:00, 27.97it/s]
 21%|██        | 54/256 [00:01<00:07, 27.89it/s]

Loss: 0.4346148593671954 


 41%|████      | 105/256 [00:03<00:05, 28.35it/s]

Loss: 0.2791475218087927 


 60%|█████▉    | 153/256 [00:05<00:03, 28.24it/s]

Loss: 0.2836790487860073 


 80%|███████▉  | 204/256 [00:07<00:01, 28.18it/s]

Loss: 0.3071344880369124 


100%|█████████▉| 255/256 [00:09<00:00, 27.98it/s]

Loss: 0.3445459137931553 


100%|██████████| 256/256 [00:09<00:00, 28.05it/s]
 21%|██        | 54/256 [00:01<00:07, 27.83it/s]

Loss: 0.4345459451680565 


 41%|████      | 105/256 [00:03<00:05, 28.31it/s]

Loss: 0.2789482628848869 


 60%|█████▉    | 153/256 [00:05<00:03, 28.02it/s]

Loss: 0.2835206443940348 


 80%|███████▉  | 204/256 [00:07<00:01, 26.35it/s]

Loss: 0.3071306245092699 


100%|█████████▉| 255/256 [00:09<00:00, 27.71it/s]

Loss: 0.34443544380281976 


100%|██████████| 256/256 [00:09<00:00, 27.79it/s]
 21%|██        | 54/256 [00:01<00:07, 28.21it/s]

Loss: 0.4344042057937588 


 41%|████      | 105/256 [00:03<00:05, 28.27it/s]

Loss: 0.2787335516473326 


 60%|█████▉    | 153/256 [00:05<00:03, 28.27it/s]

Loss: 0.28354650504035506 


 80%|███████▉  | 204/256 [00:07<00:01, 28.05it/s]

Loss: 0.30722425511866225 


100%|█████████▉| 255/256 [00:09<00:00, 27.98it/s]

Loss: 0.3443412772172384 


100%|██████████| 256/256 [00:09<00:00, 28.03it/s]
 21%|██        | 54/256 [00:01<00:07, 28.37it/s]

Loss: 0.44167722385896196 


 41%|████      | 105/256 [00:03<00:05, 28.52it/s]

Loss: 0.28908522412173016 


 60%|█████▉    | 153/256 [00:05<00:03, 28.16it/s]

Loss: 0.3010125591319524 


 80%|███████▉  | 204/256 [00:07<00:01, 27.73it/s]

Loss: 0.3213289621862806 


100%|█████████▉| 255/256 [00:09<00:00, 28.49it/s]

Loss: 0.35359156743586484 


100%|██████████| 256/256 [00:09<00:00, 28.14it/s]
 21%|██        | 54/256 [00:01<00:07, 28.15it/s]

Loss: 0.46335698726959573 


 41%|████      | 105/256 [00:03<00:05, 28.27it/s]

Loss: 0.3097488798854326 


 60%|█████▉    | 153/256 [00:05<00:03, 28.16it/s]

Loss: 0.3076478627604777 


 80%|███████▉  | 204/256 [00:07<00:01, 27.86it/s]

Loss: 0.31333727546260415 


100%|█████████▉| 255/256 [00:09<00:00, 28.37it/s]

Loss: 0.3670432150352349 


100%|██████████| 256/256 [00:09<00:00, 28.03it/s]
 21%|██        | 54/256 [00:01<00:07, 27.87it/s]

Loss: 0.4462308597929335 


 41%|████      | 105/256 [00:03<00:05, 28.15it/s]

Loss: 0.28254672739879266 


 60%|█████▉    | 153/256 [00:05<00:03, 28.05it/s]

Loss: 0.2900989810220131 


 80%|███████▉  | 204/256 [00:07<00:01, 28.25it/s]

Loss: 0.3110277272741406 


100%|█████████▉| 255/256 [00:09<00:00, 28.39it/s]

Loss: 0.3277557118127062 


100%|██████████| 256/256 [00:09<00:00, 28.03it/s]
 21%|██        | 54/256 [00:01<00:07, 28.39it/s]

Loss: 0.43675395328190697 


 41%|████      | 105/256 [00:03<00:05, 27.75it/s]

Loss: 0.28017220966971046 


 60%|█████▉    | 153/256 [00:05<00:03, 27.76it/s]

Loss: 0.28586694436703075 


 80%|███████▉  | 204/256 [00:07<00:01, 28.56it/s]

Loss: 0.30817111548821186 


100%|█████████▉| 255/256 [00:09<00:00, 28.27it/s]

Loss: 0.32784480361699536 


100%|██████████| 256/256 [00:09<00:00, 28.12it/s]
 21%|██        | 54/256 [00:01<00:07, 27.74it/s]

Loss: 0.4349073269499574 


 41%|████      | 105/256 [00:03<00:05, 27.78it/s]

Loss: 0.2789162484877233 


 60%|█████▉    | 153/256 [00:05<00:03, 27.83it/s]

Loss: 0.2844224066161004 


 80%|███████▉  | 204/256 [00:07<00:01, 28.41it/s]

Loss: 0.30745471140214986 


100%|█████████▉| 255/256 [00:09<00:00, 28.28it/s]

Loss: 0.3259974483222227 


100%|██████████| 256/256 [00:09<00:00, 28.17it/s]
 21%|██        | 54/256 [00:01<00:07, 27.79it/s]

Loss: 0.4346074834457909 


 41%|████      | 105/256 [00:03<00:05, 27.59it/s]

Loss: 0.2786510337260828 


 60%|█████▉    | 153/256 [00:05<00:03, 28.33it/s]

Loss: 0.28444424484548997 


 80%|███████▉  | 204/256 [00:07<00:01, 28.29it/s]

Loss: 0.3074075636703095 


100%|█████████▉| 255/256 [00:09<00:00, 28.16it/s]

Loss: 0.3256669005880055 


100%|██████████| 256/256 [00:09<00:00, 28.00it/s]
 21%|██        | 54/256 [00:01<00:07, 27.96it/s]

Loss: 0.43446168213993647 


 41%|████      | 105/256 [00:03<00:05, 27.79it/s]

Loss: 0.2785635396444175 


 60%|█████▉    | 153/256 [00:05<00:03, 28.60it/s]

Loss: 0.2839239347515281 


 80%|███████▉  | 204/256 [00:07<00:01, 28.54it/s]

Loss: 0.30716772248040936 


100%|█████████▉| 255/256 [00:09<00:00, 28.44it/s]

Loss: 0.3257279754561461 


100%|██████████| 256/256 [00:09<00:00, 28.26it/s]
 21%|██        | 54/256 [00:01<00:07, 27.69it/s]

Loss: 0.4343872214010067 


 41%|████      | 105/256 [00:03<00:05, 28.01it/s]

Loss: 0.27822175777782554 


 60%|█████▉    | 153/256 [00:05<00:03, 28.26it/s]

Loss: 0.28364250483237574 


 80%|███████▉  | 204/256 [00:07<00:01, 28.62it/s]

Loss: 0.3071227512007799 


100%|█████████▉| 255/256 [00:09<00:00, 27.79it/s]

Loss: 0.32552164904754366 


100%|██████████| 256/256 [00:09<00:00, 28.18it/s]
 21%|██        | 54/256 [00:01<00:07, 27.96it/s]

Loss: 0.4342625975231119 


 41%|████      | 105/256 [00:03<00:05, 28.50it/s]

Loss: 0.2778434268188354 


 60%|█████▉    | 153/256 [00:05<00:03, 28.13it/s]

Loss: 0.2835232816567688 


 80%|███████▉  | 204/256 [00:07<00:01, 28.44it/s]

Loss: 0.30706826692155237 


100%|█████████▉| 255/256 [00:09<00:00, 27.98it/s]

Loss: 0.3255469361609855 


100%|██████████| 256/256 [00:09<00:00, 28.22it/s]
 21%|██        | 54/256 [00:01<00:07, 26.67it/s]

Loss: 0.4342789744560511 


 41%|████      | 105/256 [00:03<00:05, 28.50it/s]

Loss: 0.27788380778967 


 60%|█████▉    | 153/256 [00:05<00:03, 28.20it/s]

Loss: 0.2834455221458896 


 80%|███████▉  | 204/256 [00:07<00:01, 28.26it/s]

Loss: 0.307138558055979 


100%|█████████▉| 255/256 [00:09<00:00, 27.92it/s]

Loss: 0.32538742289637246 


100%|██████████| 256/256 [00:09<00:00, 28.05it/s]
 21%|██        | 54/256 [00:01<00:07, 28.15it/s]

Loss: 0.43427817820234443 


 41%|████      | 105/256 [00:03<00:05, 28.32it/s]

Loss: 0.2777348650000199 


 60%|█████▉    | 153/256 [00:05<00:03, 28.61it/s]

Loss: 0.2834582445857779 


 80%|███████▉  | 204/256 [00:07<00:01, 28.08it/s]

Loss: 0.3070288847122892 


100%|█████████▉| 255/256 [00:09<00:00, 28.14it/s]

Loss: 0.3253851096857616 


100%|██████████| 256/256 [00:09<00:00, 28.06it/s]
 21%|██        | 54/256 [00:01<00:07, 28.43it/s]

Loss: 0.43413186037836016 


 41%|████      | 105/256 [00:03<00:05, 28.52it/s]

Loss: 0.27767897837106514 


 60%|█████▉    | 153/256 [00:05<00:03, 28.21it/s]

Loss: 0.283472399955943 


 80%|███████▉  | 204/256 [00:07<00:01, 27.80it/s]

Loss: 0.30702860417772676 


100%|█████████▉| 255/256 [00:09<00:00, 27.91it/s]

Loss: 0.3253540515415569 


100%|██████████| 256/256 [00:09<00:00, 28.10it/s]
 21%|██        | 54/256 [00:01<00:07, 28.46it/s]

Loss: 0.43406572884230704 


 41%|████      | 105/256 [00:03<00:05, 28.43it/s]

Loss: 0.2776481465250274 


 60%|█████▉    | 153/256 [00:05<00:03, 28.10it/s]

Loss: 0.28332165832851236 


 80%|███████▉  | 204/256 [00:07<00:01, 27.61it/s]

Loss: 0.3069688033318958 


100%|█████████▉| 255/256 [00:09<00:00, 28.20it/s]

Loss: 0.32535515660740005 


100%|██████████| 256/256 [00:09<00:00, 28.03it/s]
 21%|██        | 54/256 [00:01<00:07, 28.04it/s]

Loss: 0.43405416582452006 


 41%|████      | 105/256 [00:03<00:05, 28.47it/s]

Loss: 0.27763314932234595 


 60%|█████▉    | 153/256 [00:05<00:03, 28.51it/s]

Loss: 0.2832786718393305 


 80%|███████▉  | 204/256 [00:07<00:01, 28.06it/s]

Loss: 0.3070207093204714 


100%|█████████▉| 255/256 [00:09<00:00, 27.98it/s]

Loss: 0.3252970475760229 


100%|██████████| 256/256 [00:09<00:00, 28.22it/s]
 21%|██        | 54/256 [00:01<00:07, 28.59it/s]

Loss: 0.4340447123833414 


 41%|████      | 105/256 [00:03<00:05, 28.30it/s]

Loss: 0.27761095881506437 


 60%|█████▉    | 153/256 [00:05<00:03, 27.88it/s]

Loss: 0.28292427974040624 


 80%|███████▉  | 204/256 [00:07<00:01, 27.43it/s]

Loss: 0.3076816615465198 


100%|█████████▉| 255/256 [00:09<00:00, 28.49it/s]

Loss: 0.32570107070931376 


100%|██████████| 256/256 [00:09<00:00, 28.15it/s]
 21%|██        | 54/256 [00:01<00:07, 28.57it/s]

Loss: 0.43448505484173205 


 41%|████      | 105/256 [00:03<00:05, 27.98it/s]

Loss: 0.2780761819107244 


 60%|█████▉    | 153/256 [00:05<00:03, 27.97it/s]

Loss: 0.2832381809405987 


 80%|███████▉  | 204/256 [00:07<00:01, 27.84it/s]

Loss: 0.30776290987778593 


100%|█████████▉| 255/256 [00:09<00:00, 28.44it/s]

Loss: 0.32827996488978983 


100%|██████████| 256/256 [00:09<00:00, 28.07it/s]
 21%|██        | 54/256 [00:01<00:07, 27.29it/s]

Loss: 0.437286971396493 


 41%|████      | 105/256 [00:03<00:05, 28.20it/s]

Loss: 0.2915489351435773 


 60%|█████▉    | 153/256 [00:05<00:03, 27.85it/s]

Loss: 0.299452305820077 


 80%|███████▉  | 204/256 [00:07<00:01, 28.71it/s]

Loss: 0.31086550783960637 


100%|█████████▉| 255/256 [00:09<00:00, 28.60it/s]

Loss: 0.34182722379507896 


100%|██████████| 256/256 [00:09<00:00, 28.06it/s]
 21%|██        | 54/256 [00:01<00:07, 28.37it/s]

Loss: 0.4377925354778847 


 41%|████      | 105/256 [00:03<00:05, 27.83it/s]

Loss: 0.27986768264119183 


 60%|█████▉    | 153/256 [00:05<00:03, 27.89it/s]

Loss: 0.2855324128638773 


 80%|███████▉  | 204/256 [00:07<00:01, 28.46it/s]

Loss: 0.3075999211968645 


100%|█████████▉| 255/256 [00:09<00:00, 28.58it/s]

Loss: 0.3258462760358277 


100%|██████████| 256/256 [00:09<00:00, 28.16it/s]
 21%|██        | 54/256 [00:01<00:07, 28.33it/s]

Loss: 0.4390522307094705 


 41%|████      | 105/256 [00:03<00:05, 27.83it/s]

Loss: 0.2784331656751959 


 60%|█████▉    | 153/256 [00:05<00:03, 27.69it/s]

Loss: 0.28403107856843696 


 80%|███████▉  | 204/256 [00:07<00:01, 28.52it/s]

Loss: 0.30719919098612886 


100%|█████████▉| 255/256 [00:09<00:00, 28.32it/s]

Loss: 0.3254458170669715 


100%|██████████| 256/256 [00:09<00:00, 28.18it/s]
 21%|██        | 54/256 [00:01<00:07, 28.39it/s]

Loss: 0.4346231511078843 


 41%|████      | 105/256 [00:03<00:05, 27.69it/s]

Loss: 0.27815209218864256 


 60%|█████▉    | 153/256 [00:05<00:03, 26.51it/s]

Loss: 0.28394748166472494 


 80%|███████▉  | 204/256 [00:07<00:01, 28.27it/s]

Loss: 0.3072542363827902 


100%|█████████▉| 255/256 [00:09<00:00, 28.18it/s]

Loss: 0.32575634005090875 


100%|██████████| 256/256 [00:09<00:00, 27.89it/s]
 21%|██        | 54/256 [00:01<00:07, 27.90it/s]

Loss: 0.4342666852060229 


 41%|████      | 105/256 [00:03<00:05, 27.91it/s]

Loss: 0.2778437426357774 


 60%|█████▉    | 153/256 [00:05<00:03, 28.13it/s]

Loss: 0.2833987387893492 


 80%|███████▉  | 204/256 [00:07<00:01, 28.00it/s]

Loss: 0.30701708445971354 


100%|█████████▉| 255/256 [00:09<00:00, 28.03it/s]

Loss: 0.3253774964478449 


100%|██████████| 256/256 [00:09<00:00, 28.00it/s]
 21%|██        | 54/256 [00:01<00:07, 27.77it/s]

Loss: 0.43417605404524734 


 41%|████      | 105/256 [00:03<00:05, 28.06it/s]

Loss: 0.2776508168191838 


 60%|█████▉    | 153/256 [00:05<00:03, 28.65it/s]

Loss: 0.28291002229406575 


 80%|███████▉  | 204/256 [00:07<00:01, 28.44it/s]

Loss: 0.3070672051339471 


100%|█████████▉| 255/256 [00:09<00:00, 28.08it/s]

Loss: 0.32531790337928945 


100%|██████████| 256/256 [00:09<00:00, 28.10it/s]


Duration: 454 seconds





# test

In [39]:
from sklearn.metrics import classification_report, confusion_matrix
def test(model):
    loss, y_true, y_predict = test_model(model)
    print("test loss: ", loss / len(test_set))
    for key in y_predict.keys():
        if len(y_predict[key]) > 1:
            print(key)
            print(classification_report(y_true[key], y_predict[key]))

In [40]:
test(skin_race_model)

100%|██████████| 3062/3062 [00:10<00:00, 283.72it/s]


Duration: 11 seconds
test loss:  2.5538007240720546
skintone
              precision    recall  f1-score   support

           0       0.43      0.72      0.54       730
           1       0.94      0.69      0.79      2095
           2       0.34      0.38      0.36       172
           3       0.44      0.49      0.46        65

   micro avg       0.68      0.67      0.68      3062
   macro avg       0.54      0.57      0.54      3062
weighted avg       0.77      0.67      0.70      3062
 samples avg       0.67      0.67      0.67      3062

race
              precision    recall  f1-score   support

           0       0.94      0.64      0.77      1431
           1       0.77      0.94      0.84      1495
           2       0.45      0.78      0.57       136

   micro avg       0.80      0.80      0.80      3062
   macro avg       0.72      0.79      0.73      3062
weighted avg       0.83      0.80      0.80      3062
 samples avg       0.80      0.80      0.80      3062




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [41]:
test(gender_model)

100%|██████████| 3062/3062 [00:07<00:00, 410.58it/s]



Duration: 7 seconds
test loss:  0.4208441230399713
gender
              precision    recall  f1-score   support

           0       0.86      0.77      0.81       975
           1       0.90      0.94      0.92      2087

   micro avg       0.89      0.89      0.89      3062
   macro avg       0.88      0.86      0.87      3062
weighted avg       0.89      0.89      0.89      3062
 samples avg       0.89      0.89      0.89      3062



In [42]:
test(masked_model)

100%|██████████| 3062/3062 [00:07<00:00, 393.78it/s]


Duration: 8 seconds
test loss:  0.32239204365820917
masked
              precision    recall  f1-score   support

           0       0.99      1.00      1.00      2957
           1       0.89      0.86      0.87       105

   micro avg       0.99      0.99      0.99      3062
   macro avg       0.94      0.93      0.93      3062
weighted avg       0.99      0.99      0.99      3062
 samples avg       0.99      0.99      0.99      3062






In [None]:
masked_model.att_out = True

In [None]:
#mask 3
index = 24
with torch.no_grad():

    y, att = masked_model(train_set[index][0].unsqueeze(0).to(device))
    img = train_set[index][0].permute(1, 2, 0).numpy()
    print(y)
    att = att.cpu().numpy().reshape(7, 7)
    plt.subplot(1, 2, 1)
    plt.imshow(att)
    plt.subplot(1, 2, 2)
    plt.imshow(img)