In [1]:
import os.path as osp
import random
import pandas as pd
from torch_geometric.data import Data, Dataset, DataLoader
import tqdm
import glob
import pandas
import torch
import numpy as np
import multiprocessing as mp
import random
import plotly.graph_objects as go
from torch_geometric.data import Data, Dataset, DataLoader
import matplotlib._color_data as mcd
"libraries for debugging"
import sys
import os.path as osp
"custom imports"
from graphconv import GravnetModel
from optimized_oc import calc_LV_Lbeta

In [2]:
class ColorWheel:
    '''Hands out colors and returns the same color again for an equal key.

    Colors are drawn from a palette that is shuffled once, deterministically,
    from ``seed``. When the palette runs out it is refilled from the shuffled
    original, so colors repeat rather than the wheel failing.

    Keys are compared by value when they are hashable (so ``1000`` always maps
    to the same color, whichever int object carries it). Unhashable objects
    fall back to object identity.

    Args:
        colors: optional list of color strings. The list is copied, so the
            caller's list is never reordered. Defaults to the XKCD colors
            shipped with matplotlib.
        seed: seed for the palette shuffle.
    '''
    def __init__(self, colors=None, seed=44):
        if colors is None:
            palette = list(mcd.XKCD_COLORS.values())
        else:
            palette = list(colors)
        # A private RandomState yields the same permutation as seeding the
        # global generator, without resetting the RNG for the rest of the
        # notebook.
        np.random.RandomState(seed).shuffle(palette)
        self.colors = palette
        self._original_colors = palette.copy()
        self.assigned_colors = {}

    @staticmethod
    def _key(thing):
        """Return the lookup key for ``thing``: its value, or its identity if unhashable."""
        try:
            hash(thing)
        except TypeError:
            # Tagged tuple so an identity key can never collide with an int key.
            return ('__id__', id(thing))
        return thing

    def __call__(self, thing):
        key = self._key(thing)
        if key in self.assigned_colors:
            return self.assigned_colors[key]
        # assign() may have removed the last free color; refill before popping.
        if not self.colors:
            self.colors = self._original_colors.copy()
        color = self.colors.pop()
        self.assigned_colors[key] = color
        if not self.colors:
            self.colors = self._original_colors.copy()
        return color

    def assign(self, thing, color):
        """Assigns a specific color to a thing and removes it from the free pool."""
        self.assigned_colors[self._key(thing)] = color
        if color in self.colors:
            self.colors.remove(color)


def get_plotly_clusterspace(event, 
                            cluster_space_coords, 
                            clustering=None,
                            size = 1.00):
    """Build one plotly ``Scatter3d`` trace per cluster in the learned cluster space.

    Args:
        event: object whose ``y`` attribute holds per-point cluster labels;
            only read when ``clustering`` is None.
        cluster_space_coords: tensor with exactly three columns (asserted),
            one row per point.
        clustering: optional per-point labels overriding ``event.y``.
        size: marker size passed to plotly.

    Returns:
        list of ``go.Scatter3d`` traces, one per distinct label, in the sorted
        order produced by ``np.unique``. Labels 0 and -1 are drawn in gray.
    """
    assert cluster_space_coords.size(1) == 3

    colorwheel = ColorWheel()
    colorwheel.assign(0, '#bfbfbf')
    colorwheel.assign(-1, '#bfbfbf')

    if clustering is None: clustering = event.y

    # Work on detached CPU numpy copies so the function also accepts tensors
    # that require grad or live on an accelerator.
    coords = cluster_space_coords.detach().cpu().numpy()
    if torch.is_tensor(clustering):
        labels = clustering.detach().cpu().numpy()
    else:
        labels = np.asarray(clustering)

    data = []
    for cluster_index in np.unique(labels):
        points = coords[labels == cluster_index]
        data.append(go.Scatter3d(
            x=points[:,0], y=points[:,1], z=points[:,2],
            mode='markers', 
            marker=dict(
                line=dict(width=0),
                size=size,
                color= colorwheel(int(cluster_index)),
                ),
            # Each hover label reports the axis it is named after.
            hovertemplate=(
                'x=%{x:0.2f}<br>y=%{y:0.2f}<br>z=%{z:0.2f}'
                f'<br>clusterindex={cluster_index}'
                '<br>'
                ),
            name = f'cluster_{cluster_index}'
            ))
    return data

def get_plotly_truth(event,
                    size = 1.00):
    """Build one plotly ``Scatter3d`` trace per truth cluster of an event.

    The first three columns of ``event.x`` are used as the plotted x, y and z
    coordinates, and ``event.y`` supplies the per-point cluster labels.

    Args:
        event: object with tensor attributes ``x`` (features, at least three
            columns) and ``y`` (labels).
        size: marker size passed to plotly.

    Returns:
        list of ``go.Scatter3d`` traces (lines + markers), one per distinct
        label in sorted order. Label 0 is drawn in gray.
    """
    colorwheel = ColorWheel()
    colorwheel.assign(0, '#bfbfbf')

    # Detached CPU numpy copies: safe for tensors that require grad or sit on
    # an accelerator.
    features = event.x.detach().cpu().numpy()
    if torch.is_tensor(event.y):
        labels = event.y.detach().cpu().numpy()
    else:
        labels = np.asarray(event.y)

    data = []
    for cluster_index in np.unique(labels):
        points = features[labels == cluster_index]
        data.append(go.Scatter3d(
            x=points[:,0], y=points[:,1], z=points[:,2],
            mode='lines+markers', 
            marker=dict(
                line=dict(width=0),
                size=size,
                color= colorwheel(int(cluster_index)),
                ),
            # Each hover label reports the axis it is named after.
            hovertemplate=(
                'x=%{x:0.2f}<br>y=%{y:0.2f}<br>z=%{z:0.2f}'
                f'<br>clusterindex={cluster_index}'
                '<br>'
                ),
            name = f'cluster_{cluster_index}'
            ))
    return data

In [3]:
class TrackMLParticleTrackingDataset(Dataset):
    """Serve TrackML events as point-cloud ``Data`` objects for track clustering.

    Each event is described by three CSV files found under ``<root>/raw``:
    ``event*-hits.csv``, ``event*-particles.csv`` and ``event*-truth.csv``.
    The three sorted file lists are assumed to line up event by event.

    One dataset item is one (phi, eta) section of one event, so the dataset
    holds ``len(self.hits) * n_phi_sections * n_eta_sections`` items.

    Args:
        root: dataset root; the CSV files are looked up in ``<root>/raw``.
        transform: forwarded to the ``Dataset`` base class.
        n_events: if positive, only the first ``n_events`` events are used.
        layer_pairs_plus, layer_pairs, download_full_dataset: stored on the
            instance; not otherwise used by this class.
        volume_layer_ids: ``[volume_id, layer_id]`` pairs of the detector
            layers whose hits are kept.
        pt_min: particles with transverse momentum ``<= pt_min`` are treated
            as noise.
        eta_range: ``(low, high)`` range that is cut into eta sections.
        n_phi_sections, n_eta_sections: number of sections per event along
            phi and eta.
        directed, phi_slope_max, z0_max, augments, tracking, n_workers,
        n_tasks: accepted so existing call sites keep working; ignored here.
    """

    # Column -> dtype schema of the three per-event CSV files. The keys double
    # as the ``usecols`` selection.
    _HITS_DTYPES = {
        'hit_id': np.int64,
        'x': np.float32,
        'y': np.float32,
        'z': np.float32,
        'volume_id': np.int64,
        'layer_id': np.int64,
        'module_id': np.int64,
    }
    _PARTICLES_DTYPES = {
        'particle_id': np.int64,
        'vx': np.float32,
        'vy': np.float32,
        'vz': np.float32,
        'px': np.float32,
        'py': np.float32,
        'pz': np.float32,
        'q': np.int64,
        'nhits': np.int64,
    }
    _TRUTH_DTYPES = {
        'hit_id': np.int64,
        'particle_id': np.int64,
        'tx': np.float32,
        'ty': np.float32,
        'tz': np.float32,
        'tpx': np.float32,
        'tpy': np.float32,
        'tpz': np.float32,
        'weight': np.float32,
    }

    def __init__(self, root, 
                 transform=None, 
                 n_events=0,
                 directed=False, 
                 layer_pairs_plus=False,
                 volume_layer_ids=[[8, 2], [8, 4], [8, 6], [8, 8]], # layers selected
                 layer_pairs=[[7, 8], [8, 9], [9, 10]],             # connected layers
                 pt_min=2.0, 
                 eta_range=[-5, 5],                     
                 phi_slope_max=0.0006, 
                 z0_max=150,                  
                 n_phi_sections=1, 
                 n_eta_sections=1,  
                 augments = False,
                 tracking=False,                   
                 n_workers=mp.cpu_count(), 
                 n_tasks=1,               
                 download_full_dataset=False                        
             ):
        raw_dir = osp.join(root, 'raw')
        self.hits = self._sorted_event_files(raw_dir, 'hits')
        self.particles = self._sorted_event_files(raw_dir, 'particles')
        self.truth = self._sorted_event_files(raw_dir, 'truth')
        if n_events > 0:
            self.hits = self.hits[:n_events]
            self.particles = self.particles[:n_events]
            self.truth = self.truth[:n_events]

        self.layer_pairs_plus = layer_pairs_plus
        self.volume_layer_ids = torch.tensor(volume_layer_ids)
        self.layer_pairs      = torch.tensor(layer_pairs)
        self.pt_min           = pt_min
        # Copy so the instance never aliases the shared default list.
        self.eta_range        = list(eta_range)
        self.n_phi_sections   = n_phi_sections
        self.n_eta_sections   = n_eta_sections
        self.full_dataset     = download_full_dataset
        self.n_events         = n_events

        super(TrackMLParticleTrackingDataset, self).__init__(root, transform)

    @staticmethod
    def _sorted_event_files(raw_dir, kind):
        """Return the sorted paths matching ``event*-<kind>.csv`` in ``raw_dir``."""
        return sorted(glob.glob(osp.join(raw_dir, f'event*-{kind}.csv')))

    def len(self):
        """Number of items: events found times sections per event."""
        return len(self.hits) * self.n_phi_sections * self.n_eta_sections

    def __len__(self):
        return self.len()

    def read_events(self, idx):
        """Load the hits, particles and truth tables of event ``idx``.

        Returns:
            tuple ``(hits, particles, truth)`` of DataFrames restricted to the
            columns listed in the class-level dtype schemas.
        """
        def load(path, dtypes):
            return pd.read_csv(path, usecols=list(dtypes), dtype=dtypes)

        hits = load(self.hits[idx], self._HITS_DTYPES)
        particles = load(self.particles[idx], self._PARTICLES_DTYPES)
        truth = load(self.truth[idx], self._TRUTH_DTYPES)
        return hits, particles, truth

    def select_all_hits(self, hits, particles, truth, noise_label=-1,
                        n_noise=50, signal_tracks=(1, 2, 3, 4)):
        """Select hits on the configured layers and label them by track.

        Hits are kept when they lie on one of ``self.volume_layer_ids``. A hit
        is noise when its particle is missing from ``particles``, has
        ``nhits < 2`` or has ``pt <= self.pt_min``. Noise always gets label 0;
        the remaining particle ids are renumbered 1..K in ascending id order.
        The result holds the hits of ``signal_tracks`` followed by a random
        subset of the noise hits.

        Args:
            hits, particles, truth: DataFrames as returned by ``read_events``.
                None of them is modified.
            noise_label: temporary particle id given to noise hits before the
                renumbering.
            n_noise: maximum number of noise hits to keep. Fewer are kept when
                the event has fewer noise hits.
            signal_tracks: renumbered track labels to keep.

        Returns:
            tuple ``(pos, eta, particle_labels)``: ``pos`` has columns
            ``[x, y, z, r, theta, phi]``, ``eta`` is ``-log(tan(theta / 2))``
            per hit, and ``particle_labels`` holds the renumbered labels.
        """
        # One integer key per detector layer: 20 * volume_id + layer_id.
        valid_layers = (20 * self.volume_layer_ids[:, 0]
                        + self.volume_layer_ids[:, 1]).tolist()

        layer = 20 * hits['volume_id'].values + hits['layer_id'].values
        index = np.unique(layer, return_inverse=True)[1].reshape(-1)
        hits = hits[['hit_id', 'x', 'y', 'z']].assign(layer=layer, index=index)
        # Keep the layers in the configured order. A layer without hits
        # contributes an empty frame instead of raising KeyError.
        hits = pd.concat([hits[hits['layer'] == v] for v in valid_layers])
        hits = hits[['hit_id', 'x', 'y', 'z', 'index']].merge(
            truth[['hit_id', 'particle_id']], on='hit_id')

        # Derived per-hit quantities [r, phi, theta, eta].
        xs, ys, zs = hits['x'].values, hits['y'].values, hits['z'].values
        r = np.sqrt(xs**2 + ys**2)
        phi = np.arctan2(ys, xs)
        theta = np.arctan2(r, zs)
        eta = -1 * np.log(np.tan(theta / 2))
        hits = hits[['x', 'y', 'z', 'index', 'particle_id']].assign(
            r=r, phi=phi, eta=eta, theta=theta)

        # Transverse momentum, computed on a copy so the caller's frame is
        # left untouched.
        particles = particles.assign(
            pt=np.sqrt(particles['px']**2 + particles['py']**2))

        # Left merge keeps every hit; hits without a particle get nhits = pt = 0.
        selected_hits = hits.merge(particles[['particle_id', 'pt', 'nhits']],
                                   on='particle_id', how="left")
        selected_hits = selected_hits.fillna({'nhits': 0, 'pt': 0})

        # A hit is noise when (nhits < 2) OR (pt <= pt_min). Assign the result
        # back explicitly: an in-place update on the selected column does not
        # reach the frame under pandas Copy-on-Write.
        is_track = (selected_hits['nhits'] >= 2) & (selected_hits['pt'] > self.pt_min)
        selected_hits['particle_id'] = selected_hits['particle_id'].where(
            is_track, noise_label)

        # Contiguous labels: noise -> 0, tracks -> 1..K. Noise is pinned to 0
        # even if the event has no noise hits or noise_label is not the
        # smallest id.
        pids = selected_hits['particle_id'].values
        is_noise = pids == noise_label
        _, track_inverse = np.unique(pids[~is_noise], return_inverse=True)
        remapped = np.zeros(len(pids), dtype=np.int64)
        remapped[~is_noise] = track_inverse.reshape(-1) + 1
        selected_hits['remapped_pid'] = remapped

        # Random subset of the noise hits, capped by what is available.
        noise = selected_hits[selected_hits['remapped_pid'] == 0]
        n_keep = min(n_noise, len(noise))
        kept = sorted(random.sample(noise.index.to_list(), n_keep))
        noise = noise.loc[kept]

        signals = selected_hits[selected_hits['remapped_pid'].isin(list(signal_tracks))]
        selected_hits = pd.concat([signals, noise])

        def column(name):
            # Copy so torch receives a writable buffer.
            return torch.from_numpy(selected_hits[name].to_numpy(copy=True))

        pos = torch.stack([column(name) for name in
                           ('x', 'y', 'z', 'r', 'theta', 'phi')], 1)
        return pos, column('eta'), column('remapped_pid')

    def split_detector_sections(self, pos, eta, particle_labels, phi_edges, eta_edges):
        """Split hits into (phi, eta) sections.

        ``phi`` is read from the last column of ``pos``. Sections are emitted
        phi-major, i.e. section ``i * (len(eta_edges) - 1) + j`` holds phi bin
        ``i`` and eta bin ``j``. Both comparisons are strict, so a hit lying
        exactly on an edge belongs to no section.

        Returns:
            tuple ``(pos_sect, particle_label_sect)`` of equally long lists.
        """
        pos_sect, particle_label_sect = [], []
        phi = pos[:, -1]
        for phi_lo, phi_hi in zip(phi_edges[:-1], phi_edges[1:]):
            in_phi = (phi > phi_lo) & (phi < phi_hi)
            phi_pos = pos[in_phi]
            phi_eta = eta[in_phi]
            phi_labels = particle_labels[in_phi]

            for eta_lo, eta_hi in zip(eta_edges[:-1], eta_edges[1:]):
                in_eta = (phi_eta > eta_lo) & (phi_eta < eta_hi)
                pos_sect.append(phi_pos[in_eta])
                particle_label_sect.append(phi_labels[in_eta])

        return pos_sect, particle_label_sect

    def get(self, idx):
        """Return item ``idx`` as a ``Data`` object.

        ``idx`` enumerates (event, section) pairs event-major, matching
        ``len()``: event ``idx // n_sections``, section ``idx % n_sections``.
        With a single section per event this is simply event ``idx``.

        The event is re-read and the noise subset re-drawn on every call.
        """
        n_sections = self.n_phi_sections * self.n_eta_sections
        if n_sections < 1:
            raise IndexError('dataset has no sections (n_phi_sections * n_eta_sections < 1)')
        event_idx, section_idx = divmod(idx, n_sections)

        hits, particles, truth = self.read_events(event_idx)
        pos, eta, particle_labels = self.select_all_hits(hits,
                                                         particles,
                                                         truth, noise_label=0)
        tracks = torch.empty(0, 5, dtype=torch.long)
        phi_edges = np.linspace(-np.pi, np.pi, num=self.n_phi_sections + 1)
        eta_edges = np.linspace(*self.eta_range, num=self.n_eta_sections + 1)
        pos_sect, particle_label_sect = self.split_detector_sections(pos,
                                                                     eta,
                                                                     particle_labels,
                                                                     phi_edges,
                                                                     eta_edges)
        return Data(x=pos_sect[section_idx],
                    y=particle_label_sect[section_idx],
                    tracks=tracks,
                    inpz=torch.Tensor([section_idx]))

def fetch_dataloader(data_dir, 
                     batch_size, 
                     validation_split,
                     n_events = 100,
                     pt_min = 1.0,
                     n_workers = 1,
                     generate_tracks = True,
                     full_dataset = False,
                     shuffle=False,
                     random_seed=1001):
    """Build train and validation ``DataLoader`` objects over the TrackML dataset.

    Args:
        data_dir: dataset root handed to ``TrackMLParticleTrackingDataset``.
        batch_size: batch size of both loaders.
        validation_split: fraction of the items placed in the validation
            subset (rounded down). Datasets with at most two items always
            put exactly one item into validation.
        n_events, pt_min, n_workers, generate_tracks, full_dataset: forwarded
            to the dataset constructor.
        shuffle: whether the *training* loader reshuffles each epoch. The
            validation loader is never shuffled so that batch ``i`` refers to
            the same events in every epoch.
        random_seed: seed of the generator that draws the train/val split.

    Returns:
        dict with the keys ``'train'`` and ``'val'``.

    Raises:
        ValueError: if no events are found under ``data_dir``.
    """
    volume_layer_ids = [
        [8, 2], [8, 4], [8, 6], [8, 8], # barrel pixels
        [7, 2], [7, 4], [7, 6], [7, 8], [7, 10], [7, 12], [7, 14],# minus pixel endcap
        [9, 2], [9, 4], [9, 6], [9, 8], [9, 10], [9, 12], [9, 14], # plus pixel endcap
    ]
    dataset = TrackMLParticleTrackingDataset(root=data_dir,
                                             layer_pairs_plus=True, 
                                             pt_min= pt_min,
                                             volume_layer_ids=volume_layer_ids,
                                             n_events=n_events, 
                                             n_workers=n_workers, 
                                             tracking = generate_tracks,
                                             download_full_dataset=full_dataset)
    dataset_size = len(dataset)
    if dataset_size == 0:
        # Without this guard the split below would request a negative length.
        raise ValueError(f'no events found under {data_dir!r}')

    if dataset_size > 2:
        split = int(np.floor(validation_split * dataset_size))
    else: 
        split = 1
    print(split)

    train_subset, val_subset = torch.utils.data.random_split(
        dataset, [dataset_size - split, split],
        generator=torch.Generator().manual_seed(random_seed))
    print("train subset dim:", len(train_subset))
    print("validation subset dim", len(val_subset))
    dataloaders = {
        'train':  DataLoader(train_subset, batch_size=batch_size, shuffle=shuffle),
        'val':   DataLoader(val_subset, batch_size=batch_size, shuffle=False)
        }
    print("train_dataloader dim:", len(dataloaders['train']))
    print("val dataloader dim:", len(dataloaders['val']))
    return dataloaders

In [4]:
# Constant added to the summed loss returned by compute_oc_loss.
loss_offset: float = 1.0

def compute_oc_loss(out, data, s_c=1., return_components=False):
    """Evaluate ``calc_LV_Lbeta`` on a model output.

    Column 0 of ``out`` is passed through a sigmoid and columns 1..3 are used
    as cluster-space coordinates; both are handed to ``calc_LV_Lbeta`` together
    with ``data.y`` (cast to long) and ``data.batch``, using ``qmin=0.1``.

    Args:
        out: model output with at least four columns.
        data: batch object exposing ``y`` and ``batch`` tensors on the same
            device as ``out`` (asserted).
        s_c: not used by this function.
        return_components: when true, the result of ``calc_LV_Lbeta`` is
            returned unchanged.

    Returns:
        The raw ``calc_LV_Lbeta`` result if ``return_components`` is true,
        otherwise the sum of its two entries plus ``loss_offset``.
    """
    device = out.device
    pred_betas = torch.sigmoid(out[:, 0])
    pred_cluster_space_coords = out[:, 1:4]

    tensors_to_check = [pred_betas, pred_cluster_space_coords, data.y, data.batch]
    assert all(tensor.device == device for tensor in tensors_to_check)

    oc_result = calc_LV_Lbeta(
        pred_betas,
        pred_cluster_space_coords,
        data.y.long(),
        data.batch,
        return_components=return_components,
        qmin=0.1,
    )
    if not return_components:
        first_term, second_term = oc_result
        return first_term + second_term + loss_offset
    return oc_result

In [5]:
def train(data_loader, model, epoch, optimizer,interval = 1):
    """Run one optimisation pass over ``data_loader``.

    For every batch the gradients are cleared, the model is evaluated on
    ``(batch.x, batch.batch)``, the loss from ``compute_oc_loss`` is
    back-propagated and the optimizer takes a step. The loss is printed for
    every ``interval``-th batch and shown in the progress bar.

    Returns:
        the same ``model`` object that was passed in.
    """
    print('Training epoch', epoch)
    model.train()
    progress = tqdm.tqdm(data_loader, total=len(data_loader))
    progress.set_postfix({'loss': '?'})
    for step, batch in enumerate(progress):
        optimizer.zero_grad()
        prediction = model(batch.x, batch.batch)
        loss = compute_oc_loss(prediction, batch)
        should_report = step % interval == 0
        if should_report:
            print(f'loss={float(loss)}')
        loss.backward()
        optimizer.step()
        progress.set_postfix({'loss': float(loss)})
    return model

In [6]:
def test(data_loader, model, epoch, generate_plots = False):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    The loss from ``compute_oc_loss`` is averaged over the batches and
    printed. With ``generate_plots`` two HTML files per batch are written to
    the ``plots`` directory (created if missing): the truth plot and the
    predicted cluster-space plot.

    Args:
        data_loader: iterable of batches exposing ``x``, ``y`` and ``batch``.
        model: callable taking ``(x, batch)``.
        epoch: zero-based epoch number, used (plus one) in the file names.
        generate_plots: whether to write the plot files.

    Returns:
        the average loss as a float; NaN when ``data_loader`` has no batches.
    """
    # Local import: the notebook's import cell only brings in `os.path`.
    import os

    n_batches = len(data_loader)
    if generate_plots:
        # write_html does not create missing directories.
        os.makedirs("plots", exist_ok=True)

    with torch.no_grad():
        model.eval()
        loss = 0.
        data = tqdm.tqdm(data_loader, total=n_batches)
        for i,inputs in enumerate(data):
            result = model(inputs.x, inputs.batch)
            loss +=  compute_oc_loss(result,inputs)
            if generate_plots:
                pred_cluster_space_coords = result[:,1:4]
                fig = go.Figure(get_plotly_truth(inputs,size = 2.75))
                fig.write_html("plots/truth_plot_epoch_"+str(epoch+1)+"_batch_"+str(i+1)+".html")
                pred_fig=go.Figure(get_plotly_clusterspace(inputs,pred_cluster_space_coords,size = 2.75))
                pred_fig.write_html("plots/predictions_plot_epoch_"+str(epoch+1)+"_batch_"+str(i+1)+".html")
        if n_batches > 0:
            loss /= n_batches
        else:
            # Nothing was evaluated; avoid dividing by zero.
            loss = float('nan')
        print(f'Avg test loss: {loss}')
    return float(loss)
In [None]:
# --- Configuration -----------------------------------------------------------
# NOTE: machine-specific location of the TrackML data; expects the CSV files in
# "<root>/raw". Adjust before running elsewhere.
root = "/Users/gpradhan/Downloads/train_1"
noise_pt_min = 0.0
batch_size = 5
validation_split = 0.1
events = 50
model_input_dim = 6    # columns of Data.x: x, y, z, r, theta, phi
model_output_dim = 4   # column 0 -> beta, columns 1..3 -> cluster-space coords
epochs = 20
seed = 1001

# The dataset draws its noise subset with `random`; seed every generator in
# use so a fresh "Restart & Run All" reproduces the same run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Data --------------------------------------------------------------------
data = fetch_dataloader(data_dir = root,
                        batch_size = batch_size,
                        validation_split=validation_split,
                        full_dataset = False,
                        n_events = events,
                        pt_min = noise_pt_min,
                        shuffle=True)
train_loader,test_loader = data['train'],data['val']

# --- Model -------------------------------------------------------------------
# Built from the configured dimensions rather than repeating the literals.
model = GravnetModel(input_dim=model_input_dim,output_dim=model_output_dim)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)

# --- Training loop -----------------------------------------------------------
for i_epoch in range(epochs):
    model = train(train_loader, 
                  model, 
                  i_epoch, 
                  optimizer,
                  interval = 5)
    test(test_loader, 
         model, 
         i_epoch,
         generate_plots = True) 

  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

5
train subset dim: 45
validation subset dim 5
train_dataloader dim: 9
val dataloader dim: 1
Training epoch 0


 11%|█         | 1/9 [00:02<00:23,  2.99s/it, loss=2.14]

loss=2.1412487030029297


 67%|██████▋   | 6/9 [00:17<00:08,  2.89s/it, loss=2.14]

loss=2.1429831981658936


100%|██████████| 9/9 [00:25<00:00,  2.85s/it, loss=2.14]
100%|██████████| 1/1 [00:02<00:00,  2.73s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.135890007019043
Training epoch 1


 11%|█         | 1/9 [00:02<00:20,  2.61s/it, loss=2.14]

loss=2.1367688179016113


 67%|██████▋   | 6/9 [00:16<00:08,  2.84s/it, loss=2.15]

loss=2.1478421688079834


100%|██████████| 9/9 [00:25<00:00,  2.86s/it, loss=2.12]
100%|██████████| 1/1 [00:02<00:00,  2.72s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.15014910697937
Training epoch 2


 11%|█         | 1/9 [00:02<00:23,  2.95s/it, loss=2.13]

loss=2.1260814666748047


 67%|██████▋   | 6/9 [00:17<00:08,  2.89s/it, loss=2.13]

loss=2.127979278564453


100%|██████████| 9/9 [00:25<00:00,  2.87s/it, loss=2.13]
100%|██████████| 1/1 [00:02<00:00,  2.82s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.1249098777770996
Training epoch 3


 11%|█         | 1/9 [00:03<00:24,  3.01s/it, loss=2.12]

loss=2.119096279144287


 67%|██████▋   | 6/9 [00:17<00:08,  2.86s/it, loss=2.12]

loss=2.121321201324463


100%|██████████| 9/9 [00:26<00:00,  2.89s/it, loss=2.13]
100%|██████████| 1/1 [00:02<00:00,  2.75s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.112483024597168
Training epoch 4


 11%|█         | 1/9 [00:02<00:22,  2.81s/it, loss=2.12]

loss=2.1171886920928955


 67%|██████▋   | 6/9 [00:17<00:08,  2.93s/it, loss=2.11]

loss=2.113752841949463


100%|██████████| 9/9 [00:26<00:00,  2.92s/it, loss=2.11]
100%|██████████| 1/1 [00:02<00:00,  2.74s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.1321463584899902
Training epoch 5


 11%|█         | 1/9 [00:02<00:23,  2.98s/it, loss=2.1]

loss=2.102365255355835


 67%|██████▋   | 6/9 [00:17<00:08,  2.93s/it, loss=2.11]

loss=2.1095707416534424


100%|██████████| 9/9 [00:26<00:00,  2.90s/it, loss=2.11]
100%|██████████| 1/1 [00:02<00:00,  2.85s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.1477556228637695
Training epoch 6


 11%|█         | 1/9 [00:02<00:21,  2.69s/it, loss=2.09]

loss=2.0905303955078125


 67%|██████▋   | 6/9 [00:17<00:08,  2.99s/it, loss=2.1] 

loss=2.1017212867736816


100%|██████████| 9/9 [00:26<00:00,  2.91s/it, loss=2.11]
100%|██████████| 1/1 [00:02<00:00,  2.81s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.1337742805480957
Training epoch 7


 11%|█         | 1/9 [00:02<00:22,  2.84s/it, loss=2.09]

loss=2.090026378631592


 67%|██████▋   | 6/9 [00:17<00:08,  2.92s/it, loss=2.1] 

loss=2.0956039428710938


100%|██████████| 9/9 [00:26<00:00,  2.91s/it, loss=2.1] 
100%|██████████| 1/1 [00:02<00:00,  2.82s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.1030995845794678
Training epoch 8


 11%|█         | 1/9 [00:02<00:23,  2.98s/it, loss=2.09]

loss=2.0941643714904785


 67%|██████▋   | 6/9 [00:17<00:08,  2.98s/it, loss=2.07]

loss=2.072295665740967


100%|██████████| 9/9 [00:26<00:00,  2.91s/it, loss=2.08]
100%|██████████| 1/1 [00:02<00:00,  2.85s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.0721447467803955
Training epoch 9


 11%|█         | 1/9 [00:03<00:24,  3.07s/it, loss=2.04]

loss=2.043281078338623


 67%|██████▋   | 6/9 [00:17<00:08,  2.86s/it, loss=2.09]

loss=2.0894367694854736


100%|██████████| 9/9 [00:26<00:00,  2.90s/it, loss=2.08]
100%|██████████| 1/1 [00:02<00:00,  2.76s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.1052756309509277
Training epoch 10


 11%|█         | 1/9 [00:02<00:23,  2.92s/it, loss=2.08]

loss=2.0843417644500732


 67%|██████▋   | 6/9 [00:17<00:08,  2.93s/it, loss=2.03]

loss=2.0260260105133057


100%|██████████| 9/9 [00:26<00:00,  2.92s/it, loss=2.08]
100%|██████████| 1/1 [00:02<00:00,  2.74s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.0747928619384766
Training epoch 11


 11%|█         | 1/9 [00:02<00:23,  2.92s/it, loss=2.07]

loss=2.0695037841796875


 67%|██████▋   | 6/9 [00:17<00:08,  2.87s/it, loss=2.04]

loss=2.0409457683563232


100%|██████████| 9/9 [00:26<00:00,  2.93s/it, loss=2.03]
100%|██████████| 1/1 [00:02<00:00,  2.76s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.098924160003662
Training epoch 12


 11%|█         | 1/9 [00:02<00:21,  2.72s/it, loss=2.02]

loss=2.0245842933654785


 67%|██████▋   | 6/9 [00:17<00:08,  2.96s/it, loss=2.02]

loss=2.022092342376709


100%|██████████| 9/9 [00:26<00:00,  2.95s/it, loss=2.02]
100%|██████████| 1/1 [00:02<00:00,  2.84s/it]
  0%|          | 0/9 [00:00<?, ?it/s, loss=?]

Avg test loss: 2.0523979663848877
Training epoch 13


 11%|█         | 1/9 [00:03<00:25,  3.22s/it, loss=2.03]

loss=2.033064365386963


 67%|██████▋   | 6/9 [00:17<00:08,  2.99s/it, loss=2.03]

loss=2.025028705596924


 89%|████████▉ | 8/9 [00:23<00:02,  2.88s/it, loss=2.04]

In [None]:
#################### END ############################################################################################