In [71]:
import os
import json
from collections import OrderedDict

import torch
from torch import nn
from torch import optim
from tqdm.notebook import tqdm

from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Tuple, Any

import numpy as np
from matplotlib import pyplot as plt

Connect to google drive

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## 1) Load the data

In [72]:
data_dir = '/content/drive/My Drive/datasets/2009-skoltech-hack/data'

# read the data
data_path = os.path.join(data_dir, 'train_data_scaled.json')
with open(data_path, 'r') as json_file:
    data = json.load(json_file)

# read the labels
labels_path = os.path.join(data_dir, 'train_labels.json')
with open(labels_path, 'r') as json_file:
    labels = json.load(json_file)

print('Loaded')

Loaded


## 2) Define Dataset and DataLoader

In [73]:
product_counts = np.array([7418, 30983,  7375, 32560, 29704,  2652,  7720], dtype=np.float)
product_probs = 1 / (product_counts**(1.5))
product_probs /= np.sum(product_probs)

class ProductDataset(Dataset):
    def __init__(self, 
                 data: Dict[int, List[Any]], 
                 labels: Dict[int, List[int]]):
        """
        data: Dict with structure: {client_id: List of features}
        labels: Dict with structure: {client_id: List with 7 zeros or ones}
        """
        super(ProductDataset, self).__init__()
        self.data = data
        self.labels = labels
        
        self.ix_to_key = dict()
        
        for ix, key in enumerate(self.data):
            self.ix_to_key[ix] = key
            

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        
        cust_id = self.ix_to_key[idx] # get customer id
        
        products = np.zeros(7, dtype=np.float32)
        for label_ix, product_status in enumerate(self.labels[cust_id]):
            if product_status:
                products[label_ix] = 1
                
        # randomly close one product
        ones_ixs = np.where(products == 1)[0]
        ones_probs = product_probs[ones_ixs]
        ones_probs /= np.sum(ones_probs)
        chosen_ix = np.random.choice(ones_ixs, replace=False, p = ones_probs)
        
        # zero that position
        products[chosen_ix] = 0
        
        # form one-hot target
        targets = np.zeros_like(products, dtype=np.float32)
        targets[chosen_ix] = 1
        
        item = {
            "features": {},
            "targets": targets,
        }

        features = np.array(self.data[cust_id])
        features = np.hstack((features, products))
        
        item["features"] = torch.from_numpy(features).float()


        return item

In [74]:
BATCH_SIZE = 16

product_dataset = ProductDataset(data=data, 
                                 labels=labels)

product_dataloader = DataLoader(dataset=product_dataset, 
                                batch_size = BATCH_SIZE, 
                                shuffle=True, 
                                num_workers = 4)

In [75]:
product_dataset[2]['features'].shape

torch.Size([169])

## 3) Define the model

In [76]:
params = {
    'input_dim': 162+7,
    'emb_dim': 64,
    'transformer_nhead': 2,
    'transformer_dim_feedforward': 32,
    'transformer_dropout': 0.1,
    'dense_unit': 128,
    'num_layers': 2,
    'n_products': 7
}

In [77]:
class ProductModel(nn.Module):
    def __init__(self, params):
        super().__init__()

        # Let's say embedding
        self.embedding = nn.Linear(in_features=params['input_dim'], 
                                   out_features=params["emb_dim"])

        transformer_blocks = []
        for i in range(params["num_layers"]):
            transformer_block = nn.TransformerEncoderLayer(
                d_model=params["emb_dim"],
                nhead=params["transformer_nhead"],
                dim_feedforward=params["transformer_dim_feedforward"],
                dropout=params["transformer_dropout"],
            )
            transformer_blocks.append(
                (f"transformer_block_{i}", transformer_block)
            )

        self.transformer_encoder = nn.Sequential(
            OrderedDict(transformer_blocks)
        )

        self.linear = nn.Linear(
            in_features=params["emb_dim"], out_features=params["dense_unit"]
        )
        self.scorer = nn.Linear(
            in_features=params["dense_unit"],
            out_features=params["n_products"],
        )

    def forward(self, features):

        emb_features = self.embedding(features).unsqueeze(1)

        transformer_output = self.transformer_encoder(emb_features)
        pooling = torch.mean(transformer_output, dim=1)
        linear = torch.tanh(self.linear(pooling))
        merch_logits = self.scorer(linear)

        return merch_logits

### Sanity check

In [78]:
model = ProductModel(params=params)

criterion = nn.BCEWithLogitsLoss()
batch = next(iter(product_dataloader))
output = model(batch['features'])

loss = criterion(output, batch['targets'])
print(loss)

tensor(0.6913, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)


## 4) Train loop without Catalyst

In [79]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [80]:
def calculate_metrics(logits: torch.tensor, 
                      targets: torch.tensor, 
                      num_classes: int = 7):
    """
    Calculates position-wise count matched
    """

    logits = logits.detach().cpu().numpy()
    prediction = np.argmax(logits, axis=1)
    prediction_one_hot = np.zeros((prediction.size, num_classes))
    prediction_one_hot[np.arange(prediction.size), prediction] = 1
    prediction_one_hot = prediction_one_hot.astype(bool)

    targets = targets.detach().cpu().numpy().astype(bool)

    matches = np.logical_and(prediction_one_hot, targets).astype(int)
    count_matches = np.sum(matches, axis=0)
    
    return count_matches

In [81]:
def train(model, 
          train_dataloader, 
          criterion,
          optimizer,
          device: str = 'cuda', 
          num_epochs: int = 10, 
          num_classes: int = 7, 
          eps = 1e-6):
    """
    model: torch nn.Module
    train_dataloader: torch DataLoader
    """

    all_losses = []
    all_counts = []

    model = model.to(device)
    best_mean_acc = 0
    for epoch in range(num_epochs):
        running_loss = 0
        
        num_samples_per_epoch = np.zeros(num_classes, dtype=np.float64) + eps
        class_wise_counts = np.zeros(num_classes, dtype = int) 
        for batch in tqdm(train_dataloader):
            # train body
            optimizer.zero_grad()

            features = batch['features'].to(device)
            targets = batch['targets'].to(device)

            logits = model(features)

            loss = criterion(logits, batch['targets'].to(device))
            loss.backward()

            
            optimizer.step()

            # save metrics
            with torch.no_grad():
                running_loss += loss.item()
                class_wise_counts += calculate_metrics(logits=logits, targets=targets)

                all_losses.append(loss.item())
                
                num_samples_per_epoch += targets.detach().cpu().numpy().sum(axis=0)
                all_counts.append(class_wise_counts / num_samples_per_epoch)

        # epoch metrics

        running_loss /= len(train_dataloader)
        
        
        class_wise_counts = class_wise_counts.astype(np.float64) + eps
        class_wise_counts /= num_samples_per_epoch
        mean_acc = np.mean(class_wise_counts)

        print(f'EPOCH: {epoch + 1}')
        print(f'### BSE loss: {running_loss:.5f}')
        print('#### num samples: ', '\t'.join([str(round(elem)) for elem in num_samples_per_epoch]))
        print(f'###### Accuracy:', '\t'.join([str(round(elem, 5)) for elem in class_wise_counts]), f'Mean: {mean_acc:.4f}')
        print()
        if mean_acc > best_mean_acc:
          best_mean_acc = mean_acc
          torch.save(model.state_dict(), os.path.join(data_dir, f"prod_scaled_epoch_{epoch+1}_ckpt.pth"))

    return model, all_losses, all_counts


In [82]:
model = ProductModel(params=params)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr = 3*1e-4)

trained_model, losses, counts = train(model=model, 
                                      train_dataloader=product_dataloader, 
                                      criterion=criterion, 
                                      optimizer=optimizer, 
                                      device=device, 
                                      num_epochs=100)

HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 1
### BSE loss: 0.29939
#### num samples:  4669.0	12390.0	4613.0	10358.0	9779.0	2336.0	5280.0
###### Accuracy: 0.04348	0.68709	0.24865	0.51458	0.71531	0.15111	0.23352 Mean: 0.3705



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 2
### BSE loss: 0.25048
#### num samples:  4681.0	12450.0	4543.0	10483.0	9644.0	2335.0	5289.0
###### Accuracy: 0.1019	0.66803	0.31125	0.61652	0.90087	0.28522	0.35394 Mean: 0.4625



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 3
### BSE loss: 0.24485
#### num samples:  4729.0	12362.0	4527.0	10359.0	9881.0	2333.0	5234.0
###### Accuracy: 0.15162	0.66745	0.33643	0.62014	0.89991	0.28847	0.35709 Mean: 0.4744



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 4
### BSE loss: 0.24104
#### num samples:  4692.0	12403.0	4584.0	10350.0	9792.0	2352.0	5252.0
###### Accuracy: 0.16837	0.68008	0.35231	0.63488	0.89277	0.29932	0.36101 Mean: 0.4841



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 5
### BSE loss: 0.23830
#### num samples:  4761.0	12436.0	4493.0	10390.0	9797.0	2356.0	5192.0
###### Accuracy: 0.19513	0.69001	0.35299	0.64167	0.88384	0.31367	0.36017 Mean: 0.4911



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 6
### BSE loss: 0.23636
#### num samples:  4669.0	12388.0	4587.0	10479.0	9733.0	2354.0	5215.0
###### Accuracy: 0.18762	0.68122	0.37846	0.65073	0.88534	0.3254	0.37085 Mean: 0.4971



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 7
### BSE loss: 0.23448
#### num samples:  4753.0	12422.0	4554.0	10336.0	9759.0	2332.0	5269.0
###### Accuracy: 0.20219	0.69184	0.38625	0.65538	0.87652	0.31432	0.3811 Mean: 0.5011



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 8
### BSE loss: 0.23386
#### num samples:  4757.0	12384.0	4520.0	10455.0	9708.0	2340.0	5261.0
###### Accuracy: 0.21085	0.69412	0.36571	0.65806	0.87907	0.31154	0.37502 Mean: 0.4992



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 9
### BSE loss: 0.23200
#### num samples:  4722.0	12401.0	4548.0	10403.0	9751.0	2336.0	5264.0
###### Accuracy: 0.20542	0.69696	0.38105	0.66058	0.87899	0.32063	0.38868 Mean: 0.5046



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 10
### BSE loss: 0.23179
#### num samples:  4706.0	12512.0	4574.0	10348.0	9655.0	2346.0	5284.0
###### Accuracy: 0.20166	0.70564	0.40577	0.65414	0.87758	0.32907	0.38664 Mean: 0.5086



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 11
### BSE loss: 0.23043
#### num samples:  4729.0	12418.0	4578.0	10388.0	9749.0	2333.0	5230.0
###### Accuracy: 0.21421	0.70213	0.39799	0.67251	0.8764	0.33262	0.38623 Mean: 0.5117



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 12
### BSE loss: 0.22871
#### num samples:  4717.0	12413.0	4567.0	10452.0	9680.0	2353.0	5243.0
###### Accuracy: 0.21688	0.7012	0.39391	0.68188	0.87593	0.34084	0.38814 Mean: 0.5141



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 13
### BSE loss: 0.22745
#### num samples:  4755.0	12483.0	4536.0	10278.0	9763.0	2359.0	5251.0
###### Accuracy: 0.23344	0.71722	0.37787	0.68126	0.87719	0.35947	0.37974 Mean: 0.5180



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 14
### BSE loss: 0.22589
#### num samples:  4698.0	12454.0	4608.0	10398.0	9716.0	2337.0	5214.0
###### Accuracy: 0.21371	0.71953	0.39931	0.68167	0.8803	0.36671	0.39873 Mean: 0.5229



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 15
### BSE loss: 0.22419
#### num samples:  4738.0	12386.0	4567.0	10379.0	9745.0	2337.0	5273.0
###### Accuracy: 0.23217	0.72606	0.39282	0.6914	0.88179	0.37013	0.39181 Mean: 0.5266



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 16
### BSE loss: 0.22315
#### num samples:  4667.0	12443.0	4569.0	10472.0	9651.0	2335.0	5288.0
###### Accuracy: 0.22327	0.72675	0.39724	0.68669	0.88188	0.37816	0.38559 Mean: 0.5257



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 17
### BSE loss: 0.22217
#### num samples:  4714.0	12431.0	4585.0	10425.0	9683.0	2339.0	5248.0
###### Accuracy: 0.21426	0.73333	0.40174	0.69659	0.87969	0.39119	0.39787 Mean: 0.5307



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 18
### BSE loss: 0.22077
#### num samples:  4721.0	12397.0	4553.0	10360.0	9786.0	2335.0	5273.0
###### Accuracy: 0.22495	0.73873	0.40325	0.69556	0.88013	0.39101	0.39769 Mean: 0.5330



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 19
### BSE loss: 0.21948
#### num samples:  4714.0	12425.0	4532.0	10478.0	9700.0	2336.0	5240.0
###### Accuracy: 0.2255	0.74262	0.40159	0.69374	0.8799	0.3917	0.39523 Mean: 0.5329



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 20
### BSE loss: 0.21945
#### num samples:  4731.0	12503.0	4553.0	10414.0	9673.0	2332.0	5219.0
###### Accuracy: 0.22659	0.74822	0.39644	0.69531	0.88153	0.39537	0.39874 Mean: 0.5346



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 21
### BSE loss: 0.21929
#### num samples:  4718.0	12478.0	4573.0	10350.0	9720.0	2332.0	5254.0
###### Accuracy: 0.23188	0.74195	0.39646	0.69932	0.88374	0.38422	0.40084 Mean: 0.5341



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 22
### BSE loss: 0.21772
#### num samples:  4628.0	12598.0	4541.0	10414.0	9672.0	2342.0	5230.0
###### Accuracy: 0.21521	0.74877	0.39617	0.70453	0.88131	0.38898	0.40822 Mean: 0.5347



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 23
### BSE loss: 0.21710
#### num samples:  4760.0	12327.0	4556.0	10476.0	9744.0	2341.0	5221.0
###### Accuracy: 0.23782	0.75193	0.41176	0.7037	0.88383	0.39513	0.38996 Mean: 0.5392



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 24
### BSE loss: 0.21618
#### num samples:  4709.0	12522.0	4511.0	10309.0	9750.0	2363.0	5261.0
###### Accuracy: 0.24591	0.75188	0.39415	0.71122	0.88195	0.40753	0.39441 Mean: 0.5410



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 25
### BSE loss: 0.21640
#### num samples:  4724.0	12423.0	4605.0	10313.0	9786.0	2341.0	5233.0
###### Accuracy: 0.24344	0.74797	0.40239	0.70794	0.888	0.3994	0.39499 Mean: 0.5406



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 26
### BSE loss: 0.21529
#### num samples:  4649.0	12476.0	4529.0	10499.0	9680.0	2356.0	5236.0
###### Accuracy: 0.24048	0.75385	0.40053	0.71026	0.88357	0.41087	0.40565 Mean: 0.5436



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 27
### BSE loss: 0.21548
#### num samples:  4755.0	12381.0	4545.0	10362.0	9751.0	2344.0	5287.0
###### Accuracy: 0.23449	0.75212	0.39978	0.71617	0.88391	0.41766	0.40552 Mean: 0.5442



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 28
### BSE loss: 0.21433
#### num samples:  4683.0	12488.0	4533.0	10427.0	9691.0	2344.0	5259.0
###### Accuracy: 0.23297	0.75488	0.41121	0.72073	0.88474	0.40828	0.41072 Mean: 0.5462



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 29
### BSE loss: 0.21414
#### num samples:  4680.0	12412.0	4565.0	10524.0	9613.0	2361.0	5270.0
###### Accuracy: 0.23504	0.75161	0.40175	0.72073	0.88047	0.41762	0.41385 Mean: 0.5459



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 30
### BSE loss: 0.21403
#### num samples:  4733.0	12370.0	4550.0	10483.0	9706.0	2333.0	5250.0
###### Accuracy: 0.25269	0.74988	0.41758	0.71602	0.87729	0.4192	0.40762 Mean: 0.5486



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 31
### BSE loss: 0.21365
#### num samples:  4776.0	12402.0	4537.0	10395.0	9721.0	2333.0	5261.0
###### Accuracy: 0.26047	0.75206	0.40489	0.71669	0.87769	0.41449	0.40068 Mean: 0.5467



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 32
### BSE loss: 0.21302
#### num samples:  4683.0	12408.0	4596.0	10435.0	9736.0	2340.0	5227.0
###### Accuracy: 0.23959	0.75516	0.4084	0.72008	0.88157	0.42393	0.40348 Mean: 0.5475



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 33
### BSE loss: 0.21292
#### num samples:  4695.0	12410.0	4533.0	10487.0	9697.0	2311.0	5292.0
###### Accuracy: 0.25921	0.75898	0.40194	0.7207	0.88522	0.42536	0.41742 Mean: 0.5527



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 34
### BSE loss: 0.21171
#### num samples:  4688.0	12413.0	4585.0	10500.0	9634.0	2329.0	5276.0
###### Accuracy: 0.25533	0.7551	0.42857	0.72867	0.88375	0.41477	0.41926 Mean: 0.5551



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 35
### BSE loss: 0.21187
#### num samples:  4695.0	12396.0	4599.0	10324.0	9827.0	2340.0	5244.0
###### Accuracy: 0.23898	0.7575	0.42509	0.71745	0.88379	0.42479	0.42201 Mean: 0.5528



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 36
### BSE loss: 0.21165
#### num samples:  4683.0	12513.0	4541.0	10388.0	9711.0	2348.0	5241.0
###### Accuracy: 0.24706	0.75945	0.41775	0.71756	0.88323	0.43526	0.41919 Mean: 0.5542



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 37
### BSE loss: 0.21119
#### num samples:  4740.0	12379.0	4575.0	10484.0	9596.0	2345.0	5306.0
###### Accuracy: 0.25886	0.75644	0.42361	0.72291	0.87985	0.43198	0.41934 Mean: 0.5561



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 38
### BSE loss: 0.21034
#### num samples:  4631.0	12423.0	4590.0	10505.0	9623.0	2358.0	5295.0
###### Accuracy: 0.24746	0.76069	0.42658	0.72565	0.87696	0.4419	0.42247 Mean: 0.5574



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 39
### BSE loss: 0.21031
#### num samples:  4681.0	12500.0	4520.0	10433.0	9662.0	2347.0	5282.0
###### Accuracy: 0.25059	0.76296	0.4104	0.72625	0.87891	0.43204	0.42332 Mean: 0.5549



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 40
### BSE loss: 0.21014
#### num samples:  4633.0	12460.0	4595.0	10430.0	9731.0	2337.0	5239.0
###### Accuracy: 0.25728	0.75963	0.43373	0.72886	0.88049	0.43688	0.4163 Mean: 0.5590



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 41
### BSE loss: 0.20992
#### num samples:  4778.0	12448.0	4546.0	10315.0	9819.0	2360.0	5159.0
###### Accuracy: 0.27355	0.75442	0.42059	0.72622	0.88512	0.44619	0.41229 Mean: 0.5598



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 42
### BSE loss: 0.20948
#### num samples:  4772.0	12489.0	4585.0	10387.0	9658.0	2343.0	5191.0
###### Accuracy: 0.28248	0.75779	0.43141	0.73034	0.87917	0.45284	0.41129 Mean: 0.5636



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 43
### BSE loss: 0.20907
#### num samples:  4707.0	12497.0	4532.0	10418.0	9726.0	2350.0	5195.0
###### Accuracy: 0.26939	0.75938	0.42189	0.72759	0.88094	0.44213	0.42098 Mean: 0.5603



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 44
### BSE loss: 0.20922
#### num samples:  4685.0	12485.0	4590.0	10516.0	9586.0	2342.0	5221.0
###### Accuracy: 0.25315	0.75899	0.43094	0.7307	0.884	0.44065	0.41544 Mean: 0.5591



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 45
### BSE loss: 0.20881
#### num samples:  4721.0	12408.0	4501.0	10484.0	9701.0	2313.0	5297.0
###### Accuracy: 0.25588	0.75935	0.41902	0.72644	0.88383	0.45179	0.43194 Mean: 0.5612



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 46
### BSE loss: 0.20812
#### num samples:  4725.0	12510.0	4486.0	10355.0	9756.0	2334.0	5259.0
###### Accuracy: 0.26878	0.76011	0.42399	0.7325	0.88038	0.43787	0.42727 Mean: 0.5616



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 47
### BSE loss: 0.20797
#### num samples:  4747.0	12390.0	4540.0	10472.0	9728.0	2318.0	5230.0
###### Accuracy: 0.2707	0.76554	0.42467	0.72489	0.88446	0.44176	0.42543 Mean: 0.5625



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 48
### BSE loss: 0.20748
#### num samples:  4745.0	12388.0	4480.0	10578.0	9650.0	2344.0	5240.0
###### Accuracy: 0.27313	0.7609	0.41696	0.73294	0.88466	0.4552	0.43034 Mean: 0.5649



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 49
### BSE loss: 0.20829
#### num samples:  4766.0	12436.0	4557.0	10355.0	9718.0	2334.0	5259.0
###### Accuracy: 0.27696	0.75692	0.41914	0.73153	0.88094	0.46015	0.42461 Mean: 0.5643



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 50
### BSE loss: 0.20688
#### num samples:  4676.0	12411.0	4596.0	10513.0	9630.0	2351.0	5248.0
###### Accuracy: 0.26989	0.76319	0.4332	0.73538	0.88058	0.45427	0.43178 Mean: 0.5669



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 51
### BSE loss: 0.20769
#### num samples:  4712.0	12499.0	4547.0	10351.0	9723.0	2325.0	5268.0
###### Accuracy: 0.27059	0.7619	0.43963	0.72853	0.87936	0.44731	0.42502 Mean: 0.5646



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 52
### BSE loss: 0.20709
#### num samples:  4642.0	12434.0	4654.0	10423.0	9714.0	2345.0	5213.0
###### Accuracy: 0.25118	0.76106	0.448	0.72858	0.88233	0.46908	0.42624 Mean: 0.5666



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 53
### BSE loss: 0.20726
#### num samples:  4762.0	12418.0	4564.0	10339.0	9770.0	2354.0	5218.0
###### Accuracy: 0.28496	0.76623	0.42572	0.72483	0.88454	0.4554	0.43369 Mean: 0.5679



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 54
### BSE loss: 0.20659
#### num samples:  4711.0	12366.0	4585.0	10494.0	9615.0	2345.0	5309.0
###### Accuracy: 0.28126	0.76629	0.43577	0.73547	0.88237	0.45757	0.43568 Mean: 0.5706



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 55
### BSE loss: 0.20673
#### num samples:  4773.0	12509.0	4554.0	10376.0	9620.0	2332.0	5261.0
###### Accuracy: 0.29122	0.76425	0.44357	0.72918	0.88025	0.45926	0.42596 Mean: 0.5705



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 56
### BSE loss: 0.20606
#### num samples:  4730.0	12452.0	4562.0	10472.0	9682.0	2334.0	5193.0
###### Accuracy: 0.2871	0.76213	0.42964	0.73902	0.88174	0.47686	0.4248 Mean: 0.5716



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 57
### BSE loss: 0.20550
#### num samples:  4772.0	12396.0	4498.0	10390.0	9809.0	2322.0	5238.0
###### Accuracy: 0.29568	0.76742	0.43419	0.73465	0.88062	0.45995	0.4412 Mean: 0.5734



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 58
### BSE loss: 0.20551
#### num samples:  4662.0	12377.0	4608.0	10555.0	9651.0	2332.0	5240.0
###### Accuracy: 0.27456	0.76949	0.45247	0.73453	0.88115	0.46355	0.43359 Mean: 0.5728



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 59
### BSE loss: 0.20545
#### num samples:  4681.0	12419.0	4586.0	10479.0	9692.0	2330.0	5238.0
###### Accuracy: 0.28156	0.76375	0.44854	0.73471	0.88073	0.46567	0.42497 Mean: 0.5714



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 60
### BSE loss: 0.20556
#### num samples:  4783.0	12422.0	4557.0	10461.0	9626.0	2332.0	5244.0
###### Accuracy: 0.2812	0.76493	0.4424	0.73224	0.8798	0.46312	0.43268 Mean: 0.5709



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 61
### BSE loss: 0.20473
#### num samples:  4708.0	12404.0	4511.0	10457.0	9685.0	2374.0	5286.0
###### Accuracy: 0.2859	0.77153	0.43671	0.7309	0.87806	0.48189	0.44646 Mean: 0.5759



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 62
### BSE loss: 0.20449
#### num samples:  4714.0	12437.0	4561.0	10426.0	9704.0	2318.0	5265.0
###### Accuracy: 0.28532	0.76546	0.44574	0.73432	0.8783	0.46333	0.45033 Mean: 0.5747



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 63
### BSE loss: 0.20434
#### num samples:  4752.0	12451.0	4552.0	10318.0	9708.0	2340.0	5304.0
###### Accuracy: 0.29356	0.77287	0.44332	0.73493	0.87691	0.47051	0.44551 Mean: 0.5768



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 64
### BSE loss: 0.20374
#### num samples:  4683.0	12506.0	4575.0	10348.0	9676.0	2340.0	5297.0
###### Accuracy: 0.29105	0.76971	0.44743	0.7385	0.87298	0.47949	0.44516 Mean: 0.5778



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 65
### BSE loss: 0.20382
#### num samples:  4681.0	12459.0	4581.0	10456.0	9626.0	2357.0	5265.0
###### Accuracy: 0.29267	0.76989	0.45077	0.73106	0.87825	0.48791	0.44179 Mean: 0.5789



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 66
### BSE loss: 0.20350
#### num samples:  4688.0	12421.0	4533.0	10486.0	9679.0	2332.0	5286.0
###### Accuracy: 0.28498	0.76685	0.45202	0.73899	0.87984	0.48027	0.44041 Mean: 0.5776



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 67
### BSE loss: 0.20328
#### num samples:  4656.0	12435.0	4552.0	10478.0	9687.0	2354.0	5263.0
###### Accuracy: 0.28866	0.77209	0.44706	0.73249	0.87808	0.47706	0.45202 Mean: 0.5782



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 68
### BSE loss: 0.20285
#### num samples:  4721.0	12475.0	4617.0	10381.0	9649.0	2332.0	5250.0
###### Accuracy: 0.29104	0.77098	0.45961	0.73962	0.87709	0.48842	0.45448 Mean: 0.5830



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 69
### BSE loss: 0.20276
#### num samples:  4759.0	12440.0	4543.0	10433.0	9684.0	2323.0	5243.0
###### Accuracy: 0.30385	0.76905	0.45014	0.73689	0.87898	0.4817	0.44955 Mean: 0.5815



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 70
### BSE loss: 0.20276
#### num samples:  4660.0	12483.0	4525.0	10480.0	9677.0	2323.0	5277.0
###### Accuracy: 0.28884	0.77265	0.44597	0.73378	0.8825	0.48515	0.44021 Mean: 0.5784



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 71
### BSE loss: 0.20253
#### num samples:  4701.0	12437.0	4525.0	10455.0	9702.0	2307.0	5298.0
###### Accuracy: 0.301	0.77808	0.45039	0.73553	0.88085	0.48461	0.4547 Mean: 0.5836



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 72
### BSE loss: 0.20210
#### num samples:  4727.0	12492.0	4524.0	10463.0	9650.0	2351.0	5218.0
###### Accuracy: 0.30908	0.77354	0.45115	0.73707	0.88228	0.48958	0.45286 Mean: 0.5851



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 73
### BSE loss: 0.20205
#### num samples:  4693.0	12384.0	4622.0	10384.0	9768.0	2355.0	5219.0
###### Accuracy: 0.30023	0.77059	0.46733	0.74162	0.88462	0.49342	0.43552 Mean: 0.5848



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 74
### BSE loss: 0.20202
#### num samples:  4702.0	12545.0	4572.0	10374.0	9681.0	2327.0	5224.0
###### Accuracy: 0.30221	0.7733	0.45297	0.74147	0.87439	0.49162	0.44391 Mean: 0.5828



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 75
### BSE loss: 0.20191
#### num samples:  4667.0	12492.0	4588.0	10441.0	9664.0	2292.0	5281.0
###### Accuracy: 0.31348	0.7745	0.45445	0.73489	0.87976	0.48342	0.43382 Mean: 0.5820



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 76
### BSE loss: 0.20093
#### num samples:  4687.0	12484.0	4528.0	10435.0	9670.0	2355.0	5266.0
###### Accuracy: 0.3083	0.77619	0.45075	0.74231	0.87973	0.50106	0.45044 Mean: 0.5870



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 77
### BSE loss: 0.20129
#### num samples:  4691.0	12518.0	4506.0	10486.0	9618.0	2345.0	5261.0
###### Accuracy: 0.30335	0.77281	0.45761	0.74175	0.87721	0.49638	0.44972 Mean: 0.5855



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 78
### BSE loss: 0.20110
#### num samples:  4763.0	12468.0	4563.0	10375.0	9652.0	2336.0	5268.0
###### Accuracy: 0.30653	0.77462	0.4554	0.74217	0.87837	0.50685	0.44723 Mean: 0.5873



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 79
### BSE loss: 0.20136
#### num samples:  4736.0	12407.0	4601.0	10360.0	9789.0	2312.0	5220.0
###### Accuracy: 0.30279	0.77086	0.46164	0.73948	0.8816	0.50908	0.44368 Mean: 0.5870



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 80
### BSE loss: 0.20074
#### num samples:  4735.0	12485.0	4554.0	10388.0	9729.0	2320.0	5214.0
###### Accuracy: 0.31679	0.78174	0.45147	0.73806	0.88149	0.48922	0.4509 Mean: 0.5871



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 81
### BSE loss: 0.19968
#### num samples:  4688.0	12371.0	4602.0	10361.0	9818.0	2340.0	5245.0
###### Accuracy: 0.30887	0.77698	0.47479	0.73786	0.88603	0.51026	0.45338 Mean: 0.5926



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 82
### BSE loss: 0.19976
#### num samples:  4668.0	12507.0	4543.0	10444.0	9683.0	2357.0	5223.0
###### Accuracy: 0.30206	0.77892	0.46379	0.73995	0.87638	0.50912	0.44227 Mean: 0.5875



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 83
### BSE loss: 0.20032
#### num samples:  4737.0	12416.0	4577.0	10462.0	9658.0	2321.0	5254.0
###### Accuracy: 0.31243	0.77159	0.46799	0.74441	0.87793	0.50668	0.44747 Mean: 0.5898



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 84
### BSE loss: 0.19973
#### num samples:  4695.0	12547.0	4533.0	10439.0	9598.0	2346.0	5267.0
###### Accuracy: 0.31608	0.78098	0.45069	0.74155	0.87956	0.50469	0.44826 Mean: 0.5888



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 85
### BSE loss: 0.19939
#### num samples:  4678.0	12491.0	4569.0	10463.0	9686.0	2336.0	5202.0
###### Accuracy: 0.30868	0.77616	0.46093	0.74367	0.88024	0.51113	0.44944 Mean: 0.5900



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 86
### BSE loss: 0.19977
#### num samples:  4740.0	12455.0	4540.0	10440.0	9714.0	2324.0	5212.0
###### Accuracy: 0.31329	0.77736	0.4674	0.73563	0.87698	0.50818	0.44724 Mean: 0.5894



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 87
### BSE loss: 0.19936
#### num samples:  4752.0	12428.0	4489.0	10492.0	9680.0	2352.0	5232.0
###### Accuracy: 0.32386	0.78042	0.45645	0.73599	0.88285	0.50765	0.45642 Mean: 0.5919



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 88
### BSE loss: 0.19918
#### num samples:  4777.0	12461.0	4550.0	10375.0	9727.0	2339.0	5196.0
###### Accuracy: 0.32196	0.78052	0.4622	0.7441	0.88136	0.51475	0.4542 Mean: 0.5942



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 89
### BSE loss: 0.19842
#### num samples:  4678.0	12506.0	4571.0	10433.0	9685.0	2357.0	5195.0
###### Accuracy: 0.3168	0.78642	0.46511	0.73958	0.88198	0.52015	0.45371 Mean: 0.5948



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 90
### BSE loss: 0.19888
#### num samples:  4775.0	12409.0	4570.0	10489.0	9573.0	2330.0	5279.0
###### Accuracy: 0.32126	0.77823	0.47352	0.74268	0.8782	0.51974	0.45368 Mean: 0.5953



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 91
### BSE loss: 0.19826
#### num samples:  4726.0	12365.0	4527.0	10615.0	9602.0	2340.0	5250.0
###### Accuracy: 0.31824	0.77784	0.47073	0.7496	0.87482	0.50684	0.45619 Mean: 0.5935



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 92
### BSE loss: 0.19896
#### num samples:  4722.0	12368.0	4579.0	10453.0	9656.0	2335.0	5312.0
###### Accuracy: 0.32317	0.78089	0.46844	0.74457	0.87956	0.51006	0.4599 Mean: 0.5952



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 93
### BSE loss: 0.19866
#### num samples:  4744.0	12391.0	4616.0	10370.0	9732.0	2327.0	5245.0
###### Accuracy: 0.31809	0.77855	0.48419	0.74185	0.88091	0.50881	0.4591 Mean: 0.5959



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 94
### BSE loss: 0.19777
#### num samples:  4695.0	12452.0	4525.0	10493.0	9625.0	2351.0	5284.0
###### Accuracy: 0.32758	0.78357	0.46961	0.74507	0.87543	0.52233	0.45609 Mean: 0.5971



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 95
### BSE loss: 0.19816
#### num samples:  4700.0	12384.0	4588.0	10370.0	9788.0	2357.0	5238.0
###### Accuracy: 0.31936	0.78093	0.47232	0.74282	0.87403	0.50573	0.4538 Mean: 0.5927



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 96
### BSE loss: 0.19761
#### num samples:  4666.0	12477.0	4570.0	10425.0	9721.0	2307.0	5259.0
###### Accuracy: 0.32726	0.78032	0.48337	0.74926	0.8783	0.51539	0.45788 Mean: 0.5988



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 97
### BSE loss: 0.19699
#### num samples:  4692.0	12360.0	4559.0	10511.0	9711.0	2332.0	5260.0
###### Accuracy: 0.32609	0.78519	0.47642	0.74246	0.88199	0.52444	0.46673 Mean: 0.6005



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 98
### BSE loss: 0.19713
#### num samples:  4668.0	12492.0	4511.0	10403.0	9753.0	2349.0	5249.0
###### Accuracy: 0.32562	0.78202	0.47639	0.74402	0.8785	0.52107	0.46504 Mean: 0.5990



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 99
### BSE loss: 0.19748
#### num samples:  4699.0	12472.0	4571.0	10440.0	9682.0	2318.0	5243.0
###### Accuracy: 0.32369	0.77983	0.47714	0.74033	0.87565	0.5289	0.46176 Mean: 0.5982



HBox(children=(FloatProgress(value=0.0, max=3090.0), HTML(value='')))


EPOCH: 100
### BSE loss: 0.19756
#### num samples:  4699.0	12385.0	4620.0	10481.0	9659.0	2362.0	5219.0
###### Accuracy: 0.32432	0.78571	0.4842	0.73953	0.88674	0.53514	0.46273 Mean: 0.6026

