In [1]:
import os
import os.path as osp
import pickle
import random
from glob import glob

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy.typing import NDArray
from plyfile import PlyData, PlyElement
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import global_max_pool
from torch_scatter import scatter_mean
from tqdm import tqdm

# Root of the dataset: point clouds live in `<data_dir>/ply`, graphs in `<data_dir>/graph`.
data_dir = "./data/FWF_flat/10.0X10.0"
osp.exists(data_dir)

True

In [2]:
# Generate the train/test split.
# glob() yields paths in arbitrary, filesystem-dependent order; sort so the list
# (and anything derived from it) is reproducible across machines and runs.
all_names = sorted(glob(osp.join(data_dir, 'ply', '*.ply')))

# Only train: every tile except those belonging to Area_06.
train_names = [n for n in all_names if 'Area_06' not in osp.basename(n)]
train_names

## Alternative: random train/test split (disabled). Seed `random` before
## re-enabling it, otherwise the split changes on every run.
# train_ratio = 0.95
# N = len(all_names)
# k = int(train_ratio * N)
# train_i = random.sample(range(N), k) 

# train_mask = np.zeros(len(all_names)).astype(bool)
# train_mask[train_i] = True


# all_names = np.array(all_names)
# train_names = all_names[train_mask].tolist()
# test_names = all_names[~train_mask].tolist()


['./data/FWF_flat/10.0X10.0/ply/Area_01_X000-Y007.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X001-Y005.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X001-Y006.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X001-Y007.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X002-Y002.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X002-Y003.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X002-Y004.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X002-Y005.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X002-Y006.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X002-Y007.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X002-Y008.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X003-Y000.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X003-Y001.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X003-Y002.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X003-Y003.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X003-Y004.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X003-Y005.ply',
 './data/FWF_flat/10.0X10.0/ply/Area_01_X003-Y00

In [3]:

def forwardstar_to_idx(first_edg:NDArray, adj_verts:NDArray)->NDArray:
    """Convert a forward-star (CSR-style) adjacency into a ``[2, E]`` edge-index array.

    The outgoing neighbours of vertex ``u`` are
    ``adj_verts[first_edg[u]:first_edg[u + 1]]``.

    Args:
        first_edg: Offsets of length ``V + 1``; assumed non-decreasing, as in a
            forward-star layout.
        adj_verts: Flat array of target vertex ids.

    Returns:
        int64 array of shape ``[2, E]``. Row 0 holds source ids and row 1 holds
        target ids, in forward-star order. A graph with no vertices or no edges
        yields shape ``(2, 0)`` instead of raising.
    """
    first_edg = np.asarray(first_edg, dtype=np.int64)
    adj_verts = np.asarray(adj_verts)
    if first_edg.size < 2:
        return np.empty((2, 0), dtype=np.int64)

    num_verts = first_edg.size - 1
    # Out-degree of each vertex; each vertex id is repeated that many times.
    degrees = np.diff(first_edg)
    sources = np.repeat(np.arange(num_verts, dtype=np.int64), degrees)
    # The per-vertex slices are consecutive, so together they form one slice.
    targets = adj_verts[first_edg[0]:first_edg[-1]].astype(np.int64)
    # np.stack (not np.concat, which only exists in NumPy >= 2.0).
    return np.stack([sources, targets], axis=0)



In [4]:
from typing import List, Tuple
from torch_scatter import scatter_add

def make_directed(edge_index: torch.Tensor) -> torch.Tensor:
    """Append a reversed copy of every edge so each connection exists both ways.

    Expects a ``[2, E]`` edge index (row 0 = source, row 1 = target) and returns
    ``[2, 2E]``: the original edges first, then their ``(target, source)`` copies.
    Edges already present in both directions end up duplicated.
    """
    flipped = torch.stack((edge_index[1], edge_index[0]))
    return torch.cat((edge_index, flipped), dim=1)

def labels_pt2reg(y:torch.Tensor, assignment:torch.Tensor, num_classes:int,dim_size:int|None=None)->torch.Tensor:
    assert torch.all(y>=0)
    y_oh = torch.eye(num_classes)
    y_sum = scatter_add(y_oh[y], assignment, dim=0,dim_size=dim_size)
    y_reg = torch.argmax(y_sum,dim=1)
    return y_reg

class ChunkedDataset(Dataset):
    """In-memory dataset of point-cloud chunks together with their superpoint graphs.

    For every ``<data_dir>/ply/<name>.ply`` whose stem is listed in ``names``, the
    points are read from the PLY file and the graph from the pickle at
    ``<data_dir>/graph/<name>.pkl``. One dict per chunk is stored; see
    ``_load_chunk`` for its keys. Files are visited in sorted order, so indices
    are reproducible.
    """
    def __init__(self,
                 data_dir:str,
                 names:List[str],
                 num_classes:int,
                 ):
        
        self.data_dir = data_dir
        self.data = list()
        self.feat_names = list()
        self.names = names
        self.num_classes = num_classes

        # Set lookup: O(1) per file instead of scanning the whole list.
        wanted = set(names)
        # glob() order is filesystem-dependent; sort so dataset indices are stable.
        ply_paths = sorted(glob(osp.join(data_dir, 'ply', '*.ply')))
        for pcd_fp in tqdm(ply_paths):
            fn = osp.basename(pcd_fp).replace('.ply', '')
            if fn not in wanted:
                continue
            self.data.append(self._load_chunk(pcd_fp, fn))

    def _load_chunk(self, pcd_fp: str, fn: str) -> dict:
        """Load one chunk: the points in ``pcd_fp`` plus its graph from ``../graph/<fn>.pkl``."""
        graph_fp = osp.join(osp.split(pcd_fp)[0], '..', 'graph', f"{fn}.pkl")

        # Per-point attributes from the first PLY element.
        pcd = pd.DataFrame(PlyData.read(pcd_fp).elements[0].data)
        pos = pcd[['x', 'y', 'z']].to_numpy(dtype=np.float32)
        feat_names = [n for n in pcd.columns if n not in ['x', 'y', 'z', 'labels']]
        if not len(self.feat_names):
            # Recorded from the first loaded chunk only; later chunks are not checked.
            self.feat_names = feat_names
        feats = pcd[feat_names].to_numpy(dtype=np.float32)
        labels = pcd['labels'].to_numpy(dtype=np.int32)

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted graph files.
        with open(graph_fp, 'rb') as f:
            graph = pickle.load(f)

        point_feats = graph['point_feats']

        # Edges: forward-star -> [2, E] index, then add the reversed edges.
        efs = graph['edges_forwardstar']
        edge_index = forwardstar_to_idx(efs[0], efs[1])
        edge_index = torch.as_tensor(edge_index, dtype=torch.int64)
        edge_index = make_directed(edge_index)

        # Majority-vote label for each superpoint.
        labels = torch.as_tensor(labels, dtype=torch.int64)
        superpoint_idx = torch.as_tensor(graph['superpoint_idx'], dtype=torch.int64)
        labels = labels_pt2reg(labels, assignment=superpoint_idx, num_classes=self.num_classes)

        return dict(
            fn=fn,
            pos=torch.as_tensor(pos),
            feats=torch.as_tensor(feats),
            labels=labels,
            superpoint_idx=superpoint_idx,
            edge_index=edge_index,
            point_feats=torch.as_tensor(point_feats),
        )

    def __getitem__(self,idx):
        return self.data[idx]
    
    def __len__(self):
        return len(self.data)

        

            


train_ds = ChunkedDataset(  # type: ignore
    data_dir=data_dir,
    names=[osp.basename(path).replace('.ply', '') for path in train_names],
    num_classes=89,
)
# test_ds = ChunkedDataset(data_dir,[osp.basename(n).replace('.ply','') for n in test_names],num_classes=89) # type: ignore


100%|██████████| 288/288 [00:25<00:00, 11.30it/s]


In [5]:
class PointNetEncoder(nn.Module):
    """Shared per-point MLP followed by max-pooling into one embedding per superpoint.

    Each point's input is ``concat(pos, point_feats)``, so ``in_dim`` must equal
    their combined width. It passes through two Linear -> LayerNorm -> ReLU ->
    Dropout stages (``in_dim -> hidden -> latent_dim``) and is then max-pooled
    over all points that share the same ``super_idx`` value.
    """

    def __init__(self,
                 in_dim: int,
                 latent_dim: int,
                 hidden: int = 64,
                 dropout: float = 0.1
                 ):
        super().__init__()
        # Creation order is deliberate: it fixes the order in which random initial
        # weights are drawn and the state_dict layout. The `bn*` attributes are
        # LayerNorms; their names are kept for checkpoint compatibility.
        self.lin1 = nn.Linear(in_dim, hidden)
        self.bn1 = nn.LayerNorm(hidden)
        self.lin2 = nn.Linear(hidden, latent_dim)
        self.bn2 = nn.LayerNorm(latent_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self,
                pos: torch.Tensor,
                point_feats: torch.Tensor,
                super_idx: torch.Tensor
                ):
        h = torch.cat([pos, point_feats], dim=-1)
        for linear, norm in ((self.lin1, self.bn1), (self.lin2, self.bn2)):
            h = self.drop(F.relu(norm(linear(h))))
        # One row per superpoint id.
        return global_max_pool(h, super_idx)
    
class SPConv(nn.Module):
    """One round of edge-conditioned message passing between superpoints.

    For each edge ``src -> dst``, a message is built from both endpoint embeddings
    and an encoding of the centroid offset ``centroids[dst] - centroids[src]``.
    Messages are mean-aggregated at ``dst``, and each node's own embedding plus
    its aggregate go through an update MLP.
    """

    def __init__(self,
                 d: int,
                 d_e: int):
        super().__init__()
        # edge encoder: 3-D centroid offset -> d_e
        self.phi_e = nn.Sequential(
            nn.Linear(3, d_e), nn.ReLU(),
            nn.Linear(d_e, d_e), nn.ReLU(),
        )
        # message MLP: [z_src | z_dst | edge code] -> d
        self.phi_m = nn.Sequential(
            nn.Linear(2 * d + d_e, d), nn.ReLU(),
            nn.Linear(d, d),
        )
        # update MLP: [z | aggregated messages] -> d
        self.phi_u = nn.Sequential(
            nn.Linear(2 * d, d), nn.ReLU(),
        )

    def forward(self,
                z,          # [R, d]  node embeddings
                centroids,  # [R, 3]  node centroids
                edge_index  # [2, E]  (src, dst) pairs
                ):
        src, dst = edge_index.unbind(0)

        edge_code = self.phi_e(centroids[dst] - centroids[src])           # [E, d_e]
        messages = self.phi_m(torch.cat([z[src], z[dst], edge_code], dim=-1))  # [E, d]

        # Mean over incoming messages; nodes with no incoming edge receive zeros.
        pooled = scatter_mean(messages, dst, dim=0, dim_size=z.size(0))   # [R, d]

        return self.phi_u(torch.cat([z, pooled], dim=-1))                 # [R, d]





class SimpleModel(nn.Module):
    """Superpoint classifier without graph convolutions.

    Points are re-centred on their superpoint centroid and encoded by a
    ``PointNetEncoder`` into one embedding per superpoint. That embedding passes
    through two Linear -> LayerNorm -> ReLU -> Dropout stages
    (``pn_latent_dim -> super_hidden -> super_latent_dim``) and a linear classifier.

    ``forward`` returns ``(logits, z)``, where ``z`` is the embedding fed to the
    classifier.
    """

    def __init__(self,
                 in_dim: int,
                 num_classes: int,

                 pn_latent_dim: int = 128,
                 pn_hidden: int = 128,
                 pn_dropout: float = 0.1,

                 super_hidden: int = 128,
                 super_latent_dim: int = 32,
                 super_dropout: float = 0.1,

                 ):
        super().__init__()
        # Submodules are created in a fixed order (affects weight init order and
        # state_dict layout); the `bn*` names refer to LayerNorms.
        self.encoder = PointNetEncoder(in_dim, pn_latent_dim, pn_hidden, pn_dropout)
        self.lin1 = nn.Linear(pn_latent_dim, super_hidden)
        self.bn1 = nn.LayerNorm(super_hidden)
        self.lin2 = nn.Linear(super_hidden, super_latent_dim)
        self.bn2 = nn.LayerNorm(super_latent_dim)
        self.dropout = nn.Dropout(super_dropout)
        self.classifier = nn.Linear(super_latent_dim, num_classes)

    def forward(self,
                pos: torch.Tensor,
                point_feats: torch.Tensor,
                super_idx: torch.Tensor
                ):
        centroids = scatter_mean(pos, super_idx, dim=0)
        hidden = self.encoder(pos - centroids[super_idx], point_feats, super_idx)

        hidden = self.dropout(F.relu(self.bn1(self.lin1(hidden))))
        embedding = self.dropout(F.relu(self.bn2(self.lin2(hidden))))

        logits = self.classifier(embedding)
        return logits, embedding


class SPConvModel(nn.Module):
    """Superpoint classifier with ``super_num_conv`` residual ``SPConv`` layers.

    Points are re-centred on their superpoint centroid and encoded into one
    ``super_hidden``-wide embedding per superpoint. Each graph layer then computes
    ``dropout(relu(LayerNorm(conv(h) + h)))``. ``forward`` returns
    ``(logits, h)``, where ``h`` is the final embedding fed to the classifier.
    """

    def __init__(
            self,
            in_dim: int,
            num_classes: int,

            pn_hidden: int = 128,
            pn_dropout: float = 0.1,

            super_hidden: int = 128,
            super_dropout: float = 0.2,

            edge_dim: int = 64,
            super_num_conv: int = 3,
                 ):
        super().__init__()
        # Creation order is kept fixed (weight init order / state_dict layout).
        self.encoder = PointNetEncoder(in_dim, super_hidden, pn_hidden, pn_dropout)
        self.conv_blocks = nn.ModuleList(
            SPConv(super_hidden, edge_dim) for _ in range(super_num_conv)
        )
        self.ln_blocks = nn.ModuleList(
            nn.LayerNorm(super_hidden) for _ in range(super_num_conv)
        )
        self.classifier = nn.Linear(super_hidden, num_classes)
        self.dropout = nn.Dropout(super_dropout)

    def forward(self,
                pos: torch.Tensor,
                point_feats: torch.Tensor,
                super_idx: torch.Tensor,
                edge_idx: torch.Tensor,
                ):
        centroids = scatter_mean(pos, super_idx, dim=0)
        h = self.encoder(pos - centroids[super_idx], point_feats, super_idx)

        for conv, norm in zip(self.conv_blocks, self.ln_blocks):
            # Residual: the block input `h` is added to the conv output before the norm.
            h = self.dropout(F.relu(norm(conv(h, centroids, edge_idx) + h)))

        return self.classifier(h), h







    

In [6]:
# Train SimpleModel on superpoint labels with plain cross-entropy.
# NOTE: config_margin / config_lambda_c are not used in this cell (leftovers from a
# contrastive-loss SPConv experiment). They are kept in case later cells read them.

config_max_epochs = 360
config_device = 'cuda' if torch.cuda.is_available() else 'cpu'
config_margin = 1.0
config_lambda_c = 0.1
config_seed = 0
config_ckpt_path = './experiments/simple_model.pth'

# Seed the global torch RNG: weight init, dropout and DataLoader shuffling all draw from it.
torch.manual_seed(config_seed)
# Create the output folder now so a missing directory fails before training, not after.
os.makedirs(osp.dirname(config_ckpt_path), exist_ok=True)

train_dl = DataLoader(train_ds,batch_size=1,shuffle=True, collate_fn=lambda x: x[0])
# test_dl = DataLoader(test_ds,batch_size=1, collate_fn=lambda x: x[0])

model = SimpleModel(
    # Encoder input = xyz offsets (3) + per-point graph features, taken from the data.
    in_dim = 3 + train_ds[0]['point_feats'].shape[-1],
    num_classes=89,
    pn_hidden=64,
    super_hidden=64,
    pn_dropout=0.1,
    super_dropout=0.2,

).to(device=config_device)

optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(1,config_max_epochs+1):
    model.train()
    total_loss = 0.
    for d in tqdm(train_dl, total=len(train_dl), desc=f"Epoch {epoch:03d}"):
        pos         = d['pos'].to(device=config_device)
        point_feats = d['point_feats'].to(device=config_device)
        labels      = d['labels'].to(device=config_device)
        super_idx   = d['superpoint_idx'].to(device=config_device)

        logits, z = model(pos, point_feats, super_idx)
        loss = F.cross_entropy(logits, labels)

        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item()

    print(f"Epoch {epoch:03d}  average loss: {total_loss / max(len(train_dl), 1):.4f}")

# Saves the whole pickled module (not just a state_dict), as before.
torch.save(model, config_ckpt_path)

Epoch 001: 100%|██████████| 283/283 [00:14<00:00, 19.87it/s]


Epoch 01  average loss: 2.7652


Epoch 002: 100%|██████████| 283/283 [00:13<00:00, 21.47it/s]


Epoch 02  average loss: 2.0913


Epoch 003: 100%|██████████| 283/283 [00:12<00:00, 21.82it/s]


Epoch 03  average loss: 1.8753


Epoch 004: 100%|██████████| 283/283 [00:11<00:00, 24.28it/s]


Epoch 04  average loss: 1.7699


Epoch 005: 100%|██████████| 283/283 [00:11<00:00, 24.22it/s]


Epoch 05  average loss: 1.6879


Epoch 006: 100%|██████████| 283/283 [00:11<00:00, 24.61it/s]


Epoch 06  average loss: 1.6542


Epoch 007: 100%|██████████| 283/283 [00:11<00:00, 24.20it/s]


Epoch 07  average loss: 1.6058


Epoch 008: 100%|██████████| 283/283 [00:11<00:00, 24.66it/s]


Epoch 08  average loss: 1.5666


Epoch 009: 100%|██████████| 283/283 [00:11<00:00, 24.88it/s]


Epoch 09  average loss: 1.5395


Epoch 010: 100%|██████████| 283/283 [00:11<00:00, 25.13it/s]


Epoch 10  average loss: 1.5046


Epoch 011: 100%|██████████| 283/283 [00:11<00:00, 25.15it/s]


Epoch 11  average loss: 1.5080


Epoch 012: 100%|██████████| 283/283 [00:11<00:00, 25.00it/s]


Epoch 12  average loss: 1.4823


Epoch 013: 100%|██████████| 283/283 [00:11<00:00, 24.18it/s]


Epoch 13  average loss: 1.4592


Epoch 014: 100%|██████████| 283/283 [00:11<00:00, 24.69it/s]


Epoch 14  average loss: 1.4493


Epoch 015: 100%|██████████| 283/283 [00:11<00:00, 24.42it/s]


Epoch 15  average loss: 1.4255


Epoch 016: 100%|██████████| 283/283 [00:11<00:00, 24.16it/s]


Epoch 16  average loss: 1.4190


Epoch 017: 100%|██████████| 283/283 [00:11<00:00, 24.75it/s]


Epoch 17  average loss: 1.3990


Epoch 018: 100%|██████████| 283/283 [00:11<00:00, 24.89it/s]


Epoch 18  average loss: 1.3853


Epoch 019: 100%|██████████| 283/283 [00:11<00:00, 24.87it/s]


Epoch 19  average loss: 1.3703


Epoch 020: 100%|██████████| 283/283 [00:11<00:00, 25.14it/s]


Epoch 20  average loss: 1.3667


Epoch 021: 100%|██████████| 283/283 [00:11<00:00, 25.10it/s]


Epoch 21  average loss: 1.3483


Epoch 022: 100%|██████████| 283/283 [00:11<00:00, 24.60it/s]


Epoch 22  average loss: 1.3324


Epoch 023: 100%|██████████| 283/283 [00:11<00:00, 24.63it/s]


Epoch 23  average loss: 1.3396


Epoch 024: 100%|██████████| 283/283 [00:11<00:00, 24.69it/s]


Epoch 24  average loss: 1.3108


Epoch 025: 100%|██████████| 283/283 [00:11<00:00, 24.92it/s]


Epoch 25  average loss: 1.3219


Epoch 026: 100%|██████████| 283/283 [00:11<00:00, 24.59it/s]


Epoch 26  average loss: 1.3161


Epoch 027: 100%|██████████| 283/283 [00:11<00:00, 25.14it/s]


Epoch 27  average loss: 1.3002


Epoch 028: 100%|██████████| 283/283 [00:11<00:00, 25.11it/s]


Epoch 28  average loss: 1.2864


Epoch 029: 100%|██████████| 283/283 [00:11<00:00, 24.60it/s]


Epoch 29  average loss: 1.2833


Epoch 030: 100%|██████████| 283/283 [00:11<00:00, 24.09it/s]


Epoch 30  average loss: 1.2743


Epoch 031: 100%|██████████| 283/283 [00:11<00:00, 24.54it/s]


Epoch 31  average loss: 1.2681


Epoch 032: 100%|██████████| 283/283 [00:11<00:00, 24.32it/s]


Epoch 32  average loss: 1.2550


Epoch 033: 100%|██████████| 283/283 [00:11<00:00, 24.77it/s]


Epoch 33  average loss: 1.2520


Epoch 034: 100%|██████████| 283/283 [00:11<00:00, 24.25it/s]


Epoch 34  average loss: 1.2526


Epoch 035: 100%|██████████| 283/283 [00:11<00:00, 24.51it/s]


Epoch 35  average loss: 1.2532


Epoch 036: 100%|██████████| 283/283 [00:11<00:00, 24.41it/s]


Epoch 36  average loss: 1.2255


Epoch 037: 100%|██████████| 283/283 [00:11<00:00, 24.57it/s]


Epoch 37  average loss: 1.2393


Epoch 038: 100%|██████████| 283/283 [00:11<00:00, 24.61it/s]


Epoch 38  average loss: 1.2185


Epoch 039: 100%|██████████| 283/283 [00:11<00:00, 25.16it/s]


Epoch 39  average loss: 1.2124


Epoch 040: 100%|██████████| 283/283 [00:11<00:00, 25.04it/s]


Epoch 40  average loss: 1.2111


Epoch 041: 100%|██████████| 283/283 [00:11<00:00, 24.35it/s]


Epoch 41  average loss: 1.2075


Epoch 042: 100%|██████████| 283/283 [00:11<00:00, 24.52it/s]


Epoch 42  average loss: 1.2058


Epoch 043: 100%|██████████| 283/283 [00:11<00:00, 24.51it/s]


Epoch 43  average loss: 1.1925


Epoch 044: 100%|██████████| 283/283 [00:11<00:00, 24.79it/s]


Epoch 44  average loss: 1.1936


Epoch 045: 100%|██████████| 283/283 [00:11<00:00, 24.40it/s]


Epoch 45  average loss: 1.1860


Epoch 046: 100%|██████████| 283/283 [00:11<00:00, 24.76it/s]


Epoch 46  average loss: 1.1694


Epoch 047: 100%|██████████| 283/283 [00:11<00:00, 25.07it/s]


Epoch 47  average loss: 1.1850


Epoch 048: 100%|██████████| 283/283 [00:11<00:00, 24.31it/s]


Epoch 48  average loss: 1.1608


Epoch 049: 100%|██████████| 283/283 [00:11<00:00, 24.52it/s]


Epoch 49  average loss: 1.1568


Epoch 050: 100%|██████████| 283/283 [00:11<00:00, 24.56it/s]


Epoch 50  average loss: 1.1606


Epoch 051: 100%|██████████| 283/283 [00:11<00:00, 24.44it/s]


Epoch 51  average loss: 1.1527


Epoch 052: 100%|██████████| 283/283 [00:11<00:00, 24.75it/s]


Epoch 52  average loss: 1.1449


Epoch 053: 100%|██████████| 283/283 [00:11<00:00, 25.08it/s]


Epoch 53  average loss: 1.1417


Epoch 054: 100%|██████████| 283/283 [00:11<00:00, 24.47it/s]


Epoch 54  average loss: 1.1378


Epoch 055: 100%|██████████| 283/283 [00:11<00:00, 24.54it/s]


Epoch 55  average loss: 1.1293


Epoch 056: 100%|██████████| 283/283 [00:11<00:00, 24.72it/s]


Epoch 56  average loss: 1.1336


Epoch 057: 100%|██████████| 283/283 [00:11<00:00, 24.74it/s]


Epoch 57  average loss: 1.1267


Epoch 058: 100%|██████████| 283/283 [00:11<00:00, 24.59it/s]


Epoch 58  average loss: 1.1176


Epoch 059: 100%|██████████| 283/283 [00:11<00:00, 24.67it/s]


Epoch 59  average loss: 1.1128


Epoch 060: 100%|██████████| 283/283 [00:11<00:00, 24.62it/s]


Epoch 60  average loss: 1.1110


Epoch 061: 100%|██████████| 283/283 [00:11<00:00, 24.88it/s]


Epoch 61  average loss: 1.1107


Epoch 062: 100%|██████████| 283/283 [00:11<00:00, 24.55it/s]


Epoch 62  average loss: 1.1023


Epoch 063: 100%|██████████| 283/283 [00:11<00:00, 24.06it/s]


Epoch 63  average loss: 1.0846


Epoch 064: 100%|██████████| 283/283 [00:11<00:00, 24.13it/s]


Epoch 64  average loss: 1.1075


Epoch 065: 100%|██████████| 283/283 [00:11<00:00, 24.56it/s]


Epoch 65  average loss: 1.0984


Epoch 066: 100%|██████████| 283/283 [00:11<00:00, 24.46it/s]


Epoch 66  average loss: 1.0923


Epoch 067: 100%|██████████| 283/283 [00:11<00:00, 24.94it/s]


Epoch 67  average loss: 1.0927


Epoch 068: 100%|██████████| 283/283 [00:11<00:00, 23.97it/s]


Epoch 68  average loss: 1.0836


Epoch 069: 100%|██████████| 283/283 [00:11<00:00, 24.16it/s]


Epoch 69  average loss: 1.0806


Epoch 070: 100%|██████████| 283/283 [00:11<00:00, 24.32it/s]


Epoch 70  average loss: 1.0749


Epoch 071: 100%|██████████| 283/283 [00:11<00:00, 24.25it/s]


Epoch 71  average loss: 1.0752


Epoch 072: 100%|██████████| 283/283 [00:11<00:00, 24.27it/s]


Epoch 72  average loss: 1.0654


Epoch 073: 100%|██████████| 283/283 [00:11<00:00, 23.97it/s]


Epoch 73  average loss: 1.0600


Epoch 074: 100%|██████████| 283/283 [00:11<00:00, 24.33it/s]


Epoch 74  average loss: 1.0623


Epoch 075: 100%|██████████| 283/283 [00:11<00:00, 24.80it/s]


Epoch 75  average loss: 1.0621


Epoch 076: 100%|██████████| 283/283 [00:11<00:00, 25.15it/s]


Epoch 76  average loss: 1.0520


Epoch 077: 100%|██████████| 283/283 [00:11<00:00, 24.74it/s]


Epoch 77  average loss: 1.0473


Epoch 078: 100%|██████████| 283/283 [00:11<00:00, 24.44it/s]


Epoch 78  average loss: 1.0476


Epoch 079: 100%|██████████| 283/283 [00:11<00:00, 24.89it/s]


Epoch 79  average loss: 1.0393


Epoch 080: 100%|██████████| 283/283 [00:11<00:00, 24.63it/s]


Epoch 80  average loss: 1.0392


Epoch 081: 100%|██████████| 283/283 [00:11<00:00, 24.81it/s]


Epoch 81  average loss: 1.0332


Epoch 082: 100%|██████████| 283/283 [00:11<00:00, 24.29it/s]


Epoch 82  average loss: 1.0450


Epoch 083: 100%|██████████| 283/283 [00:11<00:00, 24.61it/s]


Epoch 83  average loss: 1.0280


Epoch 084: 100%|██████████| 283/283 [00:11<00:00, 24.37it/s]


Epoch 84  average loss: 1.0348


Epoch 085: 100%|██████████| 283/283 [00:11<00:00, 24.64it/s]


Epoch 85  average loss: 1.0226


Epoch 086: 100%|██████████| 283/283 [00:11<00:00, 24.73it/s]


Epoch 86  average loss: 1.0242


Epoch 087: 100%|██████████| 283/283 [00:11<00:00, 24.41it/s]


Epoch 87  average loss: 1.0221


Epoch 088: 100%|██████████| 283/283 [00:11<00:00, 24.65it/s]


Epoch 88  average loss: 1.0165


Epoch 089: 100%|██████████| 283/283 [00:11<00:00, 24.43it/s]


Epoch 89  average loss: 1.0248


Epoch 090: 100%|██████████| 283/283 [00:11<00:00, 24.45it/s]


Epoch 90  average loss: 1.0147


Epoch 091: 100%|██████████| 283/283 [00:11<00:00, 24.63it/s]


Epoch 91  average loss: 1.0140


Epoch 092: 100%|██████████| 283/283 [00:11<00:00, 24.81it/s]


Epoch 92  average loss: 1.0097


Epoch 093: 100%|██████████| 283/283 [00:11<00:00, 24.98it/s]


Epoch 93  average loss: 1.0037


Epoch 094: 100%|██████████| 283/283 [00:11<00:00, 24.87it/s]


Epoch 94  average loss: 1.0276


Epoch 095: 100%|██████████| 283/283 [00:11<00:00, 25.01it/s]


Epoch 95  average loss: 1.0043


Epoch 096: 100%|██████████| 283/283 [00:11<00:00, 24.73it/s]


Epoch 96  average loss: 0.9983


Epoch 097: 100%|██████████| 283/283 [00:11<00:00, 24.70it/s]


Epoch 97  average loss: 0.9931


Epoch 098: 100%|██████████| 283/283 [00:11<00:00, 24.71it/s]


Epoch 98  average loss: 0.9920


Epoch 099: 100%|██████████| 283/283 [00:11<00:00, 24.60it/s]


Epoch 99  average loss: 0.9851


Epoch 100: 100%|██████████| 283/283 [00:11<00:00, 23.80it/s]


Epoch 100  average loss: 0.9901


Epoch 101: 100%|██████████| 283/283 [00:11<00:00, 23.89it/s]


Epoch 101  average loss: 0.9929


Epoch 102: 100%|██████████| 283/283 [00:11<00:00, 24.06it/s]


Epoch 102  average loss: 0.9903


Epoch 103: 100%|██████████| 283/283 [00:11<00:00, 24.29it/s]


Epoch 103  average loss: 0.9774


Epoch 104: 100%|██████████| 283/283 [00:11<00:00, 24.33it/s]


Epoch 104  average loss: 0.9871


Epoch 105: 100%|██████████| 283/283 [00:12<00:00, 23.51it/s]


Epoch 105  average loss: 0.9823


Epoch 106: 100%|██████████| 283/283 [00:11<00:00, 23.84it/s]


Epoch 106  average loss: 0.9795


Epoch 107: 100%|██████████| 283/283 [00:11<00:00, 24.56it/s]


Epoch 107  average loss: 0.9789


Epoch 108: 100%|██████████| 283/283 [00:11<00:00, 23.81it/s]


Epoch 108  average loss: 0.9666


Epoch 109: 100%|██████████| 283/283 [00:11<00:00, 24.58it/s]


Epoch 109  average loss: 0.9643


Epoch 110: 100%|██████████| 283/283 [00:11<00:00, 25.17it/s]


Epoch 110  average loss: 0.9664


Epoch 111: 100%|██████████| 283/283 [00:11<00:00, 24.92it/s]


Epoch 111  average loss: 0.9722


Epoch 112: 100%|██████████| 283/283 [00:11<00:00, 24.25it/s]


Epoch 112  average loss: 0.9686


Epoch 113: 100%|██████████| 283/283 [00:11<00:00, 24.27it/s]


Epoch 113  average loss: 0.9579


Epoch 114: 100%|██████████| 283/283 [00:11<00:00, 24.72it/s]


Epoch 114  average loss: 0.9730


Epoch 115: 100%|██████████| 283/283 [00:11<00:00, 24.51it/s]


Epoch 115  average loss: 0.9575


Epoch 116: 100%|██████████| 283/283 [00:11<00:00, 24.80it/s]


Epoch 116  average loss: 0.9589


Epoch 117: 100%|██████████| 283/283 [00:11<00:00, 24.07it/s]


Epoch 117  average loss: 0.9544


Epoch 118: 100%|██████████| 283/283 [00:11<00:00, 24.95it/s]


Epoch 118  average loss: 0.9513


Epoch 119: 100%|██████████| 283/283 [00:11<00:00, 25.08it/s]


Epoch 119  average loss: 0.9542


Epoch 120: 100%|██████████| 283/283 [00:11<00:00, 25.02it/s]


Epoch 120  average loss: 0.9610


Epoch 121: 100%|██████████| 283/283 [00:11<00:00, 25.15it/s]


Epoch 121  average loss: 0.9460


Epoch 122: 100%|██████████| 283/283 [00:11<00:00, 24.76it/s]


Epoch 122  average loss: 0.9466


Epoch 123: 100%|██████████| 283/283 [00:11<00:00, 24.21it/s]


Epoch 123  average loss: 0.9396


Epoch 124: 100%|██████████| 283/283 [00:11<00:00, 24.36it/s]


Epoch 124  average loss: 0.9478


Epoch 125: 100%|██████████| 283/283 [00:11<00:00, 24.95it/s]


Epoch 125  average loss: 0.9440


Epoch 126: 100%|██████████| 283/283 [00:11<00:00, 24.14it/s]


Epoch 126  average loss: 0.9497


Epoch 127: 100%|██████████| 283/283 [00:11<00:00, 24.60it/s]


Epoch 127  average loss: 0.9399


Epoch 128: 100%|██████████| 283/283 [00:11<00:00, 24.27it/s]


Epoch 128  average loss: 0.9298


Epoch 129: 100%|██████████| 283/283 [00:11<00:00, 23.86it/s]


Epoch 129  average loss: 0.9359


Epoch 130: 100%|██████████| 283/283 [00:11<00:00, 24.99it/s]


Epoch 130  average loss: 0.9397


Epoch 131: 100%|██████████| 283/283 [00:11<00:00, 24.45it/s]


Epoch 131  average loss: 0.9332


Epoch 132: 100%|██████████| 283/283 [00:11<00:00, 24.90it/s]


Epoch 132  average loss: 0.9363


Epoch 133: 100%|██████████| 283/283 [00:11<00:00, 24.52it/s]


Epoch 133  average loss: 0.9210


Epoch 134: 100%|██████████| 283/283 [00:11<00:00, 23.76it/s]


Epoch 134  average loss: 0.9410


Epoch 135: 100%|██████████| 283/283 [00:11<00:00, 24.57it/s]


Epoch 135  average loss: 0.9258


Epoch 136: 100%|██████████| 283/283 [00:11<00:00, 24.16it/s]


Epoch 136  average loss: 0.9259


Epoch 137: 100%|██████████| 283/283 [00:11<00:00, 24.25it/s]


Epoch 137  average loss: 0.9195


Epoch 138: 100%|██████████| 283/283 [00:11<00:00, 24.11it/s]


Epoch 138  average loss: 0.9293


Epoch 139: 100%|██████████| 283/283 [00:11<00:00, 23.89it/s]


Epoch 139  average loss: 0.9203


Epoch 140: 100%|██████████| 283/283 [00:11<00:00, 24.45it/s]


Epoch 140  average loss: 0.9126


Epoch 141: 100%|██████████| 283/283 [00:11<00:00, 24.15it/s]


Epoch 141  average loss: 0.9089


Epoch 142: 100%|██████████| 283/283 [00:11<00:00, 24.72it/s]


Epoch 142  average loss: 0.9085


Epoch 143: 100%|██████████| 283/283 [00:11<00:00, 23.91it/s]


Epoch 143  average loss: 0.9134


Epoch 144: 100%|██████████| 283/283 [00:11<00:00, 24.18it/s]


Epoch 144  average loss: 0.9173


Epoch 145: 100%|██████████| 283/283 [00:11<00:00, 24.06it/s]


Epoch 145  average loss: 0.9100


Epoch 146: 100%|██████████| 283/283 [00:11<00:00, 24.63it/s]


Epoch 146  average loss: 0.9235


Epoch 147: 100%|██████████| 283/283 [00:11<00:00, 24.22it/s]


Epoch 147  average loss: 0.9106


Epoch 148: 100%|██████████| 283/283 [00:11<00:00, 25.10it/s]


Epoch 148  average loss: 0.9076


Epoch 149: 100%|██████████| 283/283 [00:11<00:00, 25.02it/s]


Epoch 149  average loss: 0.9021


Epoch 150: 100%|██████████| 283/283 [00:11<00:00, 25.14it/s]


Epoch 150  average loss: 0.8961


Epoch 151: 100%|██████████| 283/283 [00:11<00:00, 24.81it/s]


Epoch 151  average loss: 0.8987


Epoch 152: 100%|██████████| 283/283 [00:11<00:00, 24.01it/s]


Epoch 152  average loss: 0.9085


Epoch 153: 100%|██████████| 283/283 [00:11<00:00, 23.93it/s]


Epoch 153  average loss: 0.8992


Epoch 154: 100%|██████████| 283/283 [00:11<00:00, 24.13it/s]


Epoch 154  average loss: 0.9216


Epoch 155: 100%|██████████| 283/283 [00:11<00:00, 25.16it/s]


Epoch 155  average loss: 0.9006


Epoch 156: 100%|██████████| 283/283 [00:11<00:00, 23.85it/s]


Epoch 156  average loss: 0.8956


Epoch 157: 100%|██████████| 283/283 [00:11<00:00, 24.47it/s]


Epoch 157  average loss: 0.8920


Epoch 158: 100%|██████████| 283/283 [00:11<00:00, 24.24it/s]


Epoch 158  average loss: 0.8961


Epoch 159: 100%|██████████| 283/283 [00:11<00:00, 24.15it/s]


Epoch 159  average loss: 0.8966


Epoch 160: 100%|██████████| 283/283 [00:11<00:00, 24.07it/s]


Epoch 160  average loss: 0.8905


Epoch 161: 100%|██████████| 283/283 [00:11<00:00, 23.88it/s]


Epoch 161  average loss: 0.8971


Epoch 162: 100%|██████████| 283/283 [00:11<00:00, 24.61it/s]


Epoch 162  average loss: 0.8984


Epoch 163: 100%|██████████| 283/283 [00:11<00:00, 24.39it/s]


Epoch 163  average loss: 0.8942


Epoch 164: 100%|██████████| 283/283 [00:12<00:00, 23.40it/s]


Epoch 164  average loss: 0.8885


Epoch 165: 100%|██████████| 283/283 [00:11<00:00, 24.43it/s]


Epoch 165  average loss: 0.8836


Epoch 166: 100%|██████████| 283/283 [00:11<00:00, 25.08it/s]


Epoch 166  average loss: 0.9019


Epoch 167: 100%|██████████| 283/283 [00:11<00:00, 23.65it/s]


Epoch 167  average loss: 0.8906


Epoch 168: 100%|██████████| 283/283 [00:11<00:00, 24.03it/s]


Epoch 168  average loss: 0.8866


Epoch 169: 100%|██████████| 283/283 [00:11<00:00, 24.61it/s]


Epoch 169  average loss: 0.8767


Epoch 170: 100%|██████████| 283/283 [00:11<00:00, 23.90it/s]


Epoch 170  average loss: 0.8804


Epoch 171: 100%|██████████| 283/283 [00:11<00:00, 23.97it/s]


Epoch 171  average loss: 0.8919


Epoch 172: 100%|██████████| 283/283 [00:11<00:00, 23.89it/s]


Epoch 172  average loss: 0.8844


Epoch 173: 100%|██████████| 283/283 [00:12<00:00, 23.56it/s]


Epoch 173  average loss: 0.8864


Epoch 174: 100%|██████████| 283/283 [00:11<00:00, 24.60it/s]


Epoch 174  average loss: 0.8722


Epoch 175: 100%|██████████| 283/283 [00:11<00:00, 24.24it/s]


Epoch 175  average loss: 0.8762


Epoch 176: 100%|██████████| 283/283 [00:11<00:00, 24.17it/s]


Epoch 176  average loss: 0.8712


Epoch 177: 100%|██████████| 283/283 [00:12<00:00, 23.53it/s]


Epoch 177  average loss: 0.8784


Epoch 178: 100%|██████████| 283/283 [00:11<00:00, 24.21it/s]


Epoch 178  average loss: 0.8754


Epoch 179: 100%|██████████| 283/283 [00:11<00:00, 24.31it/s]


Epoch 179  average loss: 0.8751


Epoch 180: 100%|██████████| 283/283 [00:11<00:00, 24.27it/s]


Epoch 180  average loss: 0.8682


Epoch 181: 100%|██████████| 283/283 [00:11<00:00, 25.02it/s]


Epoch 181  average loss: 0.8661


Epoch 182: 100%|██████████| 283/283 [00:11<00:00, 24.10it/s]


Epoch 182  average loss: 0.8724


Epoch 183: 100%|██████████| 283/283 [00:11<00:00, 24.34it/s]


Epoch 183  average loss: 0.8648


Epoch 184: 100%|██████████| 283/283 [00:11<00:00, 25.12it/s]


Epoch 184  average loss: 0.8716


Epoch 185: 100%|██████████| 283/283 [00:11<00:00, 24.35it/s]


Epoch 185  average loss: 0.8745


Epoch 186: 100%|██████████| 283/283 [00:11<00:00, 24.85it/s]


Epoch 186  average loss: 0.8708


Epoch 187: 100%|██████████| 283/283 [00:11<00:00, 25.17it/s]


Epoch 187  average loss: 0.8678


Epoch 188: 100%|██████████| 283/283 [00:11<00:00, 23.61it/s]


Epoch 188  average loss: 0.8695


Epoch 189: 100%|██████████| 283/283 [00:11<00:00, 23.99it/s]


Epoch 189  average loss: 0.8540


Epoch 190: 100%|██████████| 283/283 [00:11<00:00, 23.73it/s]


Epoch 190  average loss: 0.8676


Epoch 191: 100%|██████████| 283/283 [00:11<00:00, 24.47it/s]


Epoch 191  average loss: 0.8601


Epoch 192: 100%|██████████| 283/283 [00:11<00:00, 24.44it/s]


Epoch 192  average loss: 0.8706


Epoch 193: 100%|██████████| 283/283 [00:11<00:00, 24.27it/s]


Epoch 193  average loss: 0.8555


Epoch 194: 100%|██████████| 283/283 [00:11<00:00, 24.15it/s]


Epoch 194  average loss: 0.8545


Epoch 195: 100%|██████████| 283/283 [00:11<00:00, 24.30it/s]


Epoch 195  average loss: 0.8570


Epoch 196: 100%|██████████| 283/283 [00:11<00:00, 24.32it/s]


Epoch 196  average loss: 0.8626


Epoch 197: 100%|██████████| 283/283 [00:11<00:00, 24.30it/s]


Epoch 197  average loss: 0.8477


Epoch 198: 100%|██████████| 283/283 [00:11<00:00, 24.82it/s]


Epoch 198  average loss: 0.8501


Epoch 199: 100%|██████████| 283/283 [00:11<00:00, 24.12it/s]


Epoch 199  average loss: 0.8618


Epoch 200: 100%|██████████| 283/283 [00:11<00:00, 24.22it/s]


Epoch 200  average loss: 0.8605


Epoch 201: 100%|██████████| 283/283 [00:11<00:00, 23.88it/s]


Epoch 201  average loss: 0.8541


Epoch 202: 100%|██████████| 283/283 [00:11<00:00, 24.54it/s]


Epoch 202  average loss: 0.8576


Epoch 203: 100%|██████████| 283/283 [00:11<00:00, 24.30it/s]


Epoch 203  average loss: 0.8559


Epoch 204: 100%|██████████| 283/283 [00:11<00:00, 24.41it/s]


Epoch 204  average loss: 0.8471


Epoch 205: 100%|██████████| 283/283 [00:11<00:00, 24.28it/s]


Epoch 205  average loss: 0.8486


Epoch 206: 100%|██████████| 283/283 [00:11<00:00, 24.28it/s]


Epoch 206  average loss: 0.8483


Epoch 207: 100%|██████████| 283/283 [00:11<00:00, 23.73it/s]


Epoch 207  average loss: 0.8512


Epoch 208: 100%|██████████| 283/283 [00:11<00:00, 24.23it/s]


Epoch 208  average loss: 0.8504


Epoch 209: 100%|██████████| 283/283 [00:11<00:00, 23.95it/s]


Epoch 209  average loss: 0.8406


Epoch 210: 100%|██████████| 283/283 [00:11<00:00, 25.10it/s]


Epoch 210  average loss: 0.8458


Epoch 211: 100%|██████████| 283/283 [00:11<00:00, 24.10it/s]


Epoch 211  average loss: 0.8420


Epoch 212: 100%|██████████| 283/283 [00:11<00:00, 25.04it/s]


Epoch 212  average loss: 0.8415


Epoch 213: 100%|██████████| 283/283 [00:11<00:00, 24.95it/s]


Epoch 213  average loss: 0.8457


Epoch 214: 100%|██████████| 283/283 [00:11<00:00, 23.81it/s]


Epoch 214  average loss: 0.8455


Epoch 215: 100%|██████████| 283/283 [00:11<00:00, 24.34it/s]


Epoch 215  average loss: 0.8457


Epoch 216: 100%|██████████| 283/283 [00:11<00:00, 24.41it/s]


Epoch 216  average loss: 0.8421


Epoch 217: 100%|██████████| 283/283 [00:11<00:00, 23.84it/s]


Epoch 217  average loss: 0.8451


Epoch 218: 100%|██████████| 283/283 [00:11<00:00, 24.50it/s]


Epoch 218  average loss: 0.8386


Epoch 219: 100%|██████████| 283/283 [00:11<00:00, 24.62it/s]


Epoch 219  average loss: 0.8355


Epoch 220: 100%|██████████| 283/283 [00:11<00:00, 24.57it/s]


Epoch 220  average loss: 0.8329


Epoch 221: 100%|██████████| 283/283 [00:11<00:00, 24.48it/s]


Epoch 221  average loss: 0.8359


Epoch 222: 100%|██████████| 283/283 [00:11<00:00, 24.10it/s]


Epoch 222  average loss: 0.8444


Epoch 223: 100%|██████████| 283/283 [00:11<00:00, 25.20it/s]


Epoch 223  average loss: 0.8357


Epoch 224: 100%|██████████| 283/283 [00:11<00:00, 25.19it/s]


Epoch 224  average loss: 0.8371


Epoch 225: 100%|██████████| 283/283 [00:11<00:00, 24.82it/s]


Epoch 225  average loss: 0.8346


Epoch 226: 100%|██████████| 283/283 [00:11<00:00, 24.08it/s]


Epoch 226  average loss: 0.8317


Epoch 227: 100%|██████████| 283/283 [00:11<00:00, 24.30it/s]


Epoch 227  average loss: 0.8313


Epoch 228: 100%|██████████| 283/283 [00:11<00:00, 24.53it/s]


Epoch 228  average loss: 0.8400


Epoch 229: 100%|██████████| 283/283 [00:11<00:00, 24.22it/s]


Epoch 229  average loss: 0.8233


Epoch 230: 100%|██████████| 283/283 [00:11<00:00, 24.04it/s]


Epoch 230  average loss: 0.8320


Epoch 231: 100%|██████████| 283/283 [00:11<00:00, 24.54it/s]


Epoch 231  average loss: 0.8372


Epoch 232: 100%|██████████| 283/283 [00:11<00:00, 24.47it/s]


Epoch 232  average loss: 0.8359


Epoch 233: 100%|██████████| 283/283 [00:11<00:00, 23.64it/s]


Epoch 233  average loss: 0.8344


Epoch 234: 100%|██████████| 283/283 [00:11<00:00, 25.18it/s]


Epoch 234  average loss: 0.8263


Epoch 235: 100%|██████████| 283/283 [00:11<00:00, 25.17it/s]


Epoch 235  average loss: 0.8392


Epoch 236: 100%|██████████| 283/283 [00:11<00:00, 25.19it/s]


Epoch 236  average loss: 0.8375


Epoch 237: 100%|██████████| 283/283 [00:11<00:00, 24.49it/s]


Epoch 237  average loss: 0.8357


Epoch 238: 100%|██████████| 283/283 [00:11<00:00, 23.61it/s]


Epoch 238  average loss: 0.8255


Epoch 239: 100%|██████████| 283/283 [00:11<00:00, 23.63it/s]


Epoch 239  average loss: 0.8218


Epoch 240: 100%|██████████| 283/283 [00:11<00:00, 23.70it/s]


Epoch 240  average loss: 0.8215


Epoch 241: 100%|██████████| 283/283 [00:12<00:00, 23.51it/s]


Epoch 241  average loss: 0.8330


Epoch 242: 100%|██████████| 283/283 [00:11<00:00, 24.17it/s]


Epoch 242  average loss: 0.8322


Epoch 243: 100%|██████████| 283/283 [00:12<00:00, 23.55it/s]


Epoch 243  average loss: 0.8237


Epoch 244: 100%|██████████| 283/283 [00:11<00:00, 24.02it/s]


Epoch 244  average loss: 0.8277


Epoch 245: 100%|██████████| 283/283 [00:11<00:00, 24.38it/s]


Epoch 245  average loss: 0.8359


Epoch 246: 100%|██████████| 283/283 [00:11<00:00, 24.17it/s]


Epoch 246  average loss: 0.8270


Epoch 247: 100%|██████████| 283/283 [00:11<00:00, 24.99it/s]


Epoch 247  average loss: 0.8164


Epoch 248: 100%|██████████| 283/283 [00:11<00:00, 24.32it/s]


Epoch 248  average loss: 0.8286


Epoch 249: 100%|██████████| 283/283 [00:11<00:00, 24.50it/s]


Epoch 249  average loss: 0.8172


Epoch 250: 100%|██████████| 283/283 [00:11<00:00, 24.43it/s]


Epoch 250  average loss: 0.8261


Epoch 251: 100%|██████████| 283/283 [00:11<00:00, 24.87it/s]


Epoch 251  average loss: 0.8313


Epoch 252: 100%|██████████| 283/283 [00:11<00:00, 24.62it/s]


Epoch 252  average loss: 0.8184


Epoch 253: 100%|██████████| 283/283 [00:11<00:00, 24.95it/s]


Epoch 253  average loss: 0.8183


Epoch 254: 100%|██████████| 283/283 [00:11<00:00, 24.50it/s]


Epoch 254  average loss: 0.8236


Epoch 255: 100%|██████████| 283/283 [00:11<00:00, 24.97it/s]


Epoch 255  average loss: 0.8099


Epoch 256: 100%|██████████| 283/283 [00:11<00:00, 25.19it/s]


Epoch 256  average loss: 0.8198


Epoch 257: 100%|██████████| 283/283 [00:11<00:00, 24.10it/s]


Epoch 257  average loss: 0.8225


Epoch 258: 100%|██████████| 283/283 [00:11<00:00, 23.78it/s]


Epoch 258  average loss: 0.8164


Epoch 259: 100%|██████████| 283/283 [00:11<00:00, 24.18it/s]


Epoch 259  average loss: 0.8222


Epoch 260: 100%|██████████| 283/283 [00:11<00:00, 24.15it/s]


Epoch 260  average loss: 0.8073


Epoch 261: 100%|██████████| 283/283 [00:11<00:00, 24.90it/s]


Epoch 261  average loss: 0.8120


Epoch 262: 100%|██████████| 283/283 [00:11<00:00, 24.01it/s]


Epoch 262  average loss: 0.8161


Epoch 263: 100%|██████████| 283/283 [00:11<00:00, 24.43it/s]


Epoch 263  average loss: 0.8122


Epoch 264: 100%|██████████| 283/283 [00:11<00:00, 24.38it/s]


Epoch 264  average loss: 0.8139


Epoch 265: 100%|██████████| 283/283 [00:11<00:00, 24.89it/s]


Epoch 265  average loss: 0.8184


Epoch 266: 100%|██████████| 283/283 [00:11<00:00, 24.92it/s]


Epoch 266  average loss: 0.8210


Epoch 267: 100%|██████████| 283/283 [00:11<00:00, 24.43it/s]


Epoch 267  average loss: 0.8135


Epoch 268: 100%|██████████| 283/283 [00:11<00:00, 24.77it/s]


Epoch 268  average loss: 0.8124


Epoch 269: 100%|██████████| 283/283 [00:11<00:00, 24.42it/s]


Epoch 269  average loss: 0.8146


Epoch 270: 100%|██████████| 283/283 [00:12<00:00, 23.53it/s]


Epoch 270  average loss: 0.8058


Epoch 271: 100%|██████████| 283/283 [00:11<00:00, 24.29it/s]


Epoch 271  average loss: 0.8123


Epoch 272: 100%|██████████| 283/283 [00:11<00:00, 24.40it/s]


Epoch 272  average loss: 0.8035


Epoch 273: 100%|██████████| 283/283 [00:11<00:00, 25.20it/s]


Epoch 273  average loss: 0.8105


Epoch 274: 100%|██████████| 283/283 [00:11<00:00, 25.15it/s]


Epoch 274  average loss: 0.8111


Epoch 275: 100%|██████████| 283/283 [00:11<00:00, 23.72it/s]


Epoch 275  average loss: 0.8068


Epoch 276: 100%|██████████| 283/283 [00:11<00:00, 24.76it/s]


Epoch 276  average loss: 0.8083


Epoch 277: 100%|██████████| 283/283 [00:11<00:00, 24.73it/s]


Epoch 277  average loss: 0.8146


Epoch 278: 100%|██████████| 283/283 [00:11<00:00, 24.43it/s]


Epoch 278  average loss: 0.8049


Epoch 279: 100%|██████████| 283/283 [00:11<00:00, 24.13it/s]


Epoch 279  average loss: 0.8008


Epoch 280: 100%|██████████| 283/283 [00:11<00:00, 23.71it/s]


Epoch 280  average loss: 0.8030


Epoch 281: 100%|██████████| 283/283 [00:11<00:00, 23.81it/s]


Epoch 281  average loss: 0.8095


Epoch 282: 100%|██████████| 283/283 [00:11<00:00, 24.05it/s]


Epoch 282  average loss: 0.8080


Epoch 283: 100%|██████████| 283/283 [00:11<00:00, 23.98it/s]


Epoch 283  average loss: 0.8114


Epoch 284: 100%|██████████| 283/283 [00:11<00:00, 24.24it/s]


Epoch 284  average loss: 0.7935


Epoch 285: 100%|██████████| 283/283 [00:11<00:00, 24.44it/s]


Epoch 285  average loss: 0.8005


Epoch 286: 100%|██████████| 283/283 [00:11<00:00, 24.81it/s]


Epoch 286  average loss: 0.8018


Epoch 287: 100%|██████████| 283/283 [00:11<00:00, 24.76it/s]


Epoch 287  average loss: 0.8130


Epoch 288: 100%|██████████| 283/283 [00:11<00:00, 24.91it/s]


Epoch 288  average loss: 0.8045


Epoch 289: 100%|██████████| 283/283 [00:11<00:00, 24.95it/s]


Epoch 289  average loss: 0.8056


Epoch 290: 100%|██████████| 283/283 [00:11<00:00, 24.56it/s]


Epoch 290  average loss: 0.7889


Epoch 291: 100%|██████████| 283/283 [00:11<00:00, 24.52it/s]


Epoch 291  average loss: 0.8075


Epoch 292: 100%|██████████| 283/283 [00:11<00:00, 24.69it/s]


Epoch 292  average loss: 0.7883


Epoch 293: 100%|██████████| 283/283 [00:11<00:00, 24.73it/s]


Epoch 293  average loss: 0.7994


Epoch 294: 100%|██████████| 283/283 [00:11<00:00, 25.15it/s]


Epoch 294  average loss: 0.8048


Epoch 295: 100%|██████████| 283/283 [00:11<00:00, 25.02it/s]


Epoch 295  average loss: 0.8039


Epoch 296: 100%|██████████| 283/283 [00:11<00:00, 24.49it/s]


Epoch 296  average loss: 0.8038


Epoch 297: 100%|██████████| 283/283 [00:11<00:00, 25.20it/s]


Epoch 297  average loss: 0.8021


Epoch 298: 100%|██████████| 283/283 [00:11<00:00, 25.25it/s]


Epoch 298  average loss: 0.8109


Epoch 299: 100%|██████████| 283/283 [00:11<00:00, 25.03it/s]


Epoch 299  average loss: 0.8143


Epoch 300: 100%|██████████| 283/283 [00:11<00:00, 24.88it/s]


Epoch 300  average loss: 0.7946


Epoch 301: 100%|██████████| 283/283 [00:11<00:00, 24.95it/s]


Epoch 301  average loss: 0.8022


Epoch 302: 100%|██████████| 283/283 [00:11<00:00, 24.39it/s]


Epoch 302  average loss: 0.7928


Epoch 303: 100%|██████████| 283/283 [00:11<00:00, 24.31it/s]


Epoch 303  average loss: 0.7947


Epoch 304: 100%|██████████| 283/283 [00:11<00:00, 24.70it/s]


Epoch 304  average loss: 0.7895


Epoch 305: 100%|██████████| 283/283 [00:11<00:00, 24.51it/s]


Epoch 305  average loss: 0.7935


Epoch 306: 100%|██████████| 283/283 [00:11<00:00, 24.84it/s]


Epoch 306  average loss: 0.8015


Epoch 307: 100%|██████████| 283/283 [00:11<00:00, 24.66it/s]


Epoch 307  average loss: 0.7995


Epoch 308: 100%|██████████| 283/283 [00:11<00:00, 24.72it/s]


Epoch 308  average loss: 0.8006


Epoch 309: 100%|██████████| 283/283 [00:11<00:00, 25.24it/s]


Epoch 309  average loss: 0.7941


Epoch 310: 100%|██████████| 283/283 [00:11<00:00, 24.51it/s]


Epoch 310  average loss: 0.7895


Epoch 311: 100%|██████████| 283/283 [00:11<00:00, 24.27it/s]


Epoch 311  average loss: 0.7918


Epoch 312: 100%|██████████| 283/283 [00:11<00:00, 24.48it/s]


Epoch 312  average loss: 0.7897


Epoch 313: 100%|██████████| 283/283 [00:11<00:00, 24.76it/s]


Epoch 313  average loss: 0.7996


Epoch 314: 100%|██████████| 283/283 [00:11<00:00, 24.65it/s]


Epoch 314  average loss: 0.7861


Epoch 315: 100%|██████████| 283/283 [00:11<00:00, 24.93it/s]


Epoch 315  average loss: 0.7869


Epoch 316: 100%|██████████| 283/283 [00:11<00:00, 24.78it/s]


Epoch 316  average loss: 0.7874


Epoch 317: 100%|██████████| 283/283 [00:11<00:00, 24.72it/s]


Epoch 317  average loss: 0.7951


Epoch 318: 100%|██████████| 283/283 [00:11<00:00, 24.58it/s]


Epoch 318  average loss: 0.7969


Epoch 319: 100%|██████████| 283/283 [00:11<00:00, 24.82it/s]


Epoch 319  average loss: 0.7911


Epoch 320: 100%|██████████| 283/283 [00:11<00:00, 24.89it/s]


Epoch 320  average loss: 0.7936


Epoch 321: 100%|██████████| 283/283 [00:11<00:00, 24.62it/s]


Epoch 321  average loss: 0.7904


Epoch 322: 100%|██████████| 283/283 [00:11<00:00, 24.78it/s]


Epoch 322  average loss: 0.7950


Epoch 323: 100%|██████████| 283/283 [00:11<00:00, 24.74it/s]


Epoch 323  average loss: 0.7885


Epoch 324: 100%|██████████| 283/283 [00:11<00:00, 25.14it/s]


Epoch 324  average loss: 0.7790


Epoch 325: 100%|██████████| 283/283 [00:11<00:00, 24.85it/s]


Epoch 325  average loss: 0.7860


Epoch 326: 100%|██████████| 283/283 [00:11<00:00, 24.58it/s]


Epoch 326  average loss: 0.8031


Epoch 327: 100%|██████████| 283/283 [00:11<00:00, 25.06it/s]


Epoch 327  average loss: 0.7861


Epoch 328: 100%|██████████| 283/283 [00:11<00:00, 24.44it/s]


Epoch 328  average loss: 0.7896


Epoch 329: 100%|██████████| 283/283 [00:11<00:00, 24.47it/s]


Epoch 329  average loss: 0.7852


Epoch 330: 100%|██████████| 283/283 [00:11<00:00, 24.52it/s]


Epoch 330  average loss: 0.7906


Epoch 331: 100%|██████████| 283/283 [00:11<00:00, 24.39it/s]


Epoch 331  average loss: 0.7794


Epoch 332: 100%|██████████| 283/283 [00:11<00:00, 25.05it/s]


Epoch 332  average loss: 0.7862


Epoch 333: 100%|██████████| 283/283 [00:11<00:00, 24.28it/s]


Epoch 333  average loss: 0.7822


Epoch 334: 100%|██████████| 283/283 [00:11<00:00, 24.81it/s]


Epoch 334  average loss: 0.7848


Epoch 335: 100%|██████████| 283/283 [00:11<00:00, 25.12it/s]


Epoch 335  average loss: 0.7884


Epoch 336: 100%|██████████| 283/283 [00:11<00:00, 24.87it/s]


Epoch 336  average loss: 0.7958


Epoch 337: 100%|██████████| 283/283 [00:11<00:00, 24.62it/s]


Epoch 337  average loss: 0.7882


Epoch 338: 100%|██████████| 283/283 [00:11<00:00, 24.36it/s]


Epoch 338  average loss: 0.7763


Epoch 339: 100%|██████████| 283/283 [00:11<00:00, 24.99it/s]


Epoch 339  average loss: 0.7980


Epoch 340: 100%|██████████| 283/283 [00:11<00:00, 24.89it/s]


Epoch 340  average loss: 0.7853


Epoch 341: 100%|██████████| 283/283 [00:11<00:00, 24.93it/s]


Epoch 341  average loss: 0.7848


Epoch 342: 100%|██████████| 283/283 [00:11<00:00, 25.20it/s]


Epoch 342  average loss: 0.7825


Epoch 343: 100%|██████████| 283/283 [00:11<00:00, 25.07it/s]


Epoch 343  average loss: 0.7791


Epoch 344: 100%|██████████| 283/283 [00:11<00:00, 24.40it/s]


Epoch 344  average loss: 0.7847


Epoch 345: 100%|██████████| 283/283 [00:11<00:00, 24.38it/s]


Epoch 345  average loss: 0.7832


Epoch 346: 100%|██████████| 283/283 [00:11<00:00, 25.21it/s]


Epoch 346  average loss: 0.7848


Epoch 347: 100%|██████████| 283/283 [00:11<00:00, 24.77it/s]


Epoch 347  average loss: 0.7958


Epoch 348: 100%|██████████| 283/283 [00:11<00:00, 24.34it/s]


Epoch 348  average loss: 0.7785


Epoch 349: 100%|██████████| 283/283 [00:11<00:00, 24.46it/s]


Epoch 349  average loss: 0.7814


Epoch 350: 100%|██████████| 283/283 [00:11<00:00, 24.31it/s]


Epoch 350  average loss: 0.7773


Epoch 351: 100%|██████████| 283/283 [00:11<00:00, 24.54it/s]


Epoch 351  average loss: 0.7747


Epoch 352: 100%|██████████| 283/283 [00:11<00:00, 24.18it/s]


Epoch 352  average loss: 0.7852


Epoch 353: 100%|██████████| 283/283 [00:11<00:00, 24.41it/s]


Epoch 353  average loss: 0.7884


Epoch 354: 100%|██████████| 283/283 [00:11<00:00, 24.25it/s]


Epoch 354  average loss: 0.7835


Epoch 355: 100%|██████████| 283/283 [00:11<00:00, 23.72it/s]


Epoch 355  average loss: 0.7797


Epoch 356: 100%|██████████| 283/283 [00:11<00:00, 24.81it/s]


Epoch 356  average loss: 0.7730


Epoch 357: 100%|██████████| 283/283 [00:11<00:00, 24.78it/s]


Epoch 357  average loss: 0.7799


Epoch 358: 100%|██████████| 283/283 [00:11<00:00, 24.38it/s]


Epoch 358  average loss: 0.7766


Epoch 359: 100%|██████████| 283/283 [00:11<00:00, 24.34it/s]


Epoch 359  average loss: 0.7802


Epoch 360: 100%|██████████| 283/283 [00:11<00:00, 23.88it/s]

Epoch 360  average loss: 0.7839





In [9]:
# contrastive loss spconv
# Train SPConvModel with cross-entropy plus an edge-wise contrastive term:
#   loss = CE(logits, labels) + lambda_c * mean_e[ same_e * d_e + (1 - same_e) * relu(margin - d_e) ]
# where, for every edge e = (src, dst) in `edge_index`, d_e is the squared L2
# distance between the embeddings z[src] and z[dst], and same_e is 1 when both
# endpoints carry the same label (pull together) and 0 otherwise (push apart
# until d_e >= margin).
import os  # local import: only needed to create the checkpoint directory below

config_max_epochs = 300
config_device = 'cuda'
config_margin = 1.0        # hinge margin; note it is compared against the *squared* distance
config_lambda_c = 0.1      # weight of the contrastive term relative to cross-entropy
config_lr = 1e-3
config_weight_decay = 1e-4
config_seed = 42           # seeds model init, dropout and DataLoader shuffling
config_save_path = './experiments/spconv_contrastive_L0-1.pth'

torch.manual_seed(config_seed)

# collate_fn unwraps the singleton batch so each step sees one sample dict
train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])

model = SPConvModel(
    in_dim=3 + 12,  # presumably 3 coordinates + 12 per-point features — TODO confirm against dataset
    num_classes=89,
    pn_hidden=64,
    super_hidden=64,
    pn_dropout=0.1,
    super_dropout=0.2,
    edge_dim=32,
    super_num_conv=3
).to(device=config_device)

optim = torch.optim.Adam(model.parameters(), lr=config_lr, weight_decay=config_weight_decay)

for epoch in range(1, config_max_epochs + 1):
    model.train()
    total_loss = 0.
    total_class = 0.
    total_contrast = 0.
    for d in tqdm(train_dl, total=len(train_dl), desc=f"Epoch {epoch:03d}"):
        pos         = d['pos'].to(device=config_device)
        point_feats = d['point_feats'].to(device=config_device)
        labels      = d['labels'].to(device=config_device)
        super_idx   = d['superpoint_idx'].to(device=config_device)
        edge_idx    = d['edge_index'].to(device=config_device)

        logits, z = model(pos, point_feats, super_idx, edge_idx)
        L_class = F.cross_entropy(logits, labels)

        # contrastive loss over graph edges
        # NOTE(review): labels[src] assumes `labels`, `z` and `edge_index` index the
        # same node set — confirm against SPConvModel's outputs.
        src, dst = edge_idx
        if src.numel() > 0:
            dist2 = (z[src] - z[dst]).pow(2).sum(dim=1)      # [E]
            same = (labels[src] == labels[dst]).float()      # [E]
            loss_pos = same * dist2                          # pull same-label pairs together
            loss_neg = (1. - same) * F.relu(config_margin - dist2)  # push others apart
            L_contrast = (loss_pos + loss_neg).mean()
        else:
            # mean() over zero edges is NaN and would poison every weight via backward();
            # a sample without edges simply contributes no contrastive signal.
            L_contrast = z.new_zeros(())

        # total loss
        loss = L_class + config_lambda_c * L_contrast
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item()
        total_class += L_class.item()
        total_contrast += L_contrast.item()

    n_steps = len(train_dl)
    print(f"Epoch {epoch:03d}  average loss: {total_loss/n_steps:.4f}"
          f"  (class: {total_class/n_steps:.4f}, contrast: {total_contrast/n_steps:.4f})")

# Create the target directory so a missing folder doesn't discard the whole run.
os.makedirs(osp.dirname(config_save_path), exist_ok=True)
# NOTE(review): this pickles the whole module (requires the same SPConvModel class
# definition to load); saving model.state_dict() is the more portable option.
torch.save(model, config_save_path)

Epoch 001:   0%|          | 0/283 [00:00<?, ?it/s]

Epoch 001: 100%|██████████| 283/283 [00:14<00:00, 20.21it/s]


Epoch 01  average loss: 4.2621


Epoch 002: 100%|██████████| 283/283 [00:14<00:00, 20.13it/s]


Epoch 02  average loss: 3.0484


Epoch 003: 100%|██████████| 283/283 [00:13<00:00, 21.75it/s]


Epoch 03  average loss: 2.8538


Epoch 004: 100%|██████████| 283/283 [00:14<00:00, 20.18it/s]


Epoch 04  average loss: 2.7604


Epoch 005: 100%|██████████| 283/283 [00:11<00:00, 24.23it/s]


Epoch 05  average loss: 2.6085


Epoch 006: 100%|██████████| 283/283 [00:12<00:00, 21.92it/s]


Epoch 06  average loss: 2.5084


Epoch 007: 100%|██████████| 283/283 [00:13<00:00, 20.94it/s]


Epoch 07  average loss: 2.2589


Epoch 008: 100%|██████████| 283/283 [00:13<00:00, 21.52it/s]


Epoch 08  average loss: 2.0962


Epoch 009: 100%|██████████| 283/283 [00:13<00:00, 20.48it/s]


Epoch 09  average loss: 1.9994


Epoch 010: 100%|██████████| 283/283 [00:13<00:00, 20.66it/s]


Epoch 10  average loss: 1.9247


Epoch 011: 100%|██████████| 283/283 [00:13<00:00, 21.49it/s]


Epoch 11  average loss: 1.8647


Epoch 012: 100%|██████████| 283/283 [00:13<00:00, 21.74it/s]


Epoch 12  average loss: 1.8417


Epoch 013: 100%|██████████| 283/283 [00:13<00:00, 20.81it/s]


Epoch 13  average loss: 1.7940


Epoch 014: 100%|██████████| 283/283 [00:12<00:00, 22.18it/s]


Epoch 14  average loss: 1.7601


Epoch 015: 100%|██████████| 283/283 [00:13<00:00, 21.31it/s]


Epoch 15  average loss: 1.7276


Epoch 016: 100%|██████████| 283/283 [00:12<00:00, 21.83it/s]


Epoch 16  average loss: 1.6848


Epoch 017: 100%|██████████| 283/283 [00:12<00:00, 22.07it/s]


Epoch 17  average loss: 1.6753


Epoch 018: 100%|██████████| 283/283 [00:13<00:00, 20.97it/s]


Epoch 18  average loss: 1.6520


Epoch 019: 100%|██████████| 283/283 [00:12<00:00, 21.83it/s]


Epoch 19  average loss: 1.6118


Epoch 020: 100%|██████████| 283/283 [00:11<00:00, 24.37it/s]


Epoch 20  average loss: 1.6000


Epoch 021: 100%|██████████| 283/283 [00:12<00:00, 21.93it/s]


Epoch 21  average loss: 1.6002


Epoch 022: 100%|██████████| 283/283 [00:13<00:00, 21.03it/s]


Epoch 22  average loss: 1.5806


Epoch 023: 100%|██████████| 283/283 [00:13<00:00, 21.08it/s]


Epoch 23  average loss: 1.5482


Epoch 024: 100%|██████████| 283/283 [00:11<00:00, 24.60it/s]


Epoch 24  average loss: 1.5217


Epoch 025: 100%|██████████| 283/283 [00:12<00:00, 23.23it/s]


Epoch 25  average loss: 1.5197


Epoch 026: 100%|██████████| 283/283 [00:12<00:00, 23.18it/s]


Epoch 26  average loss: 1.5031


Epoch 027: 100%|██████████| 283/283 [00:12<00:00, 23.37it/s]


Epoch 27  average loss: 1.4967


Epoch 028: 100%|██████████| 283/283 [00:12<00:00, 23.45it/s]


Epoch 28  average loss: 1.5138


Epoch 029: 100%|██████████| 283/283 [00:11<00:00, 23.76it/s]


Epoch 29  average loss: 1.4763


Epoch 030: 100%|██████████| 283/283 [00:13<00:00, 20.44it/s]


Epoch 30  average loss: 1.4442


Epoch 031: 100%|██████████| 283/283 [00:12<00:00, 22.91it/s]


Epoch 31  average loss: 1.4601


Epoch 032: 100%|██████████| 283/283 [00:13<00:00, 20.72it/s]


Epoch 32  average loss: 1.4201


Epoch 033: 100%|██████████| 283/283 [00:13<00:00, 21.48it/s]


Epoch 33  average loss: 1.4330


Epoch 034: 100%|██████████| 283/283 [00:12<00:00, 22.37it/s]


Epoch 34  average loss: 1.4252


Epoch 035: 100%|██████████| 283/283 [00:13<00:00, 21.07it/s]


Epoch 35  average loss: 1.4231


Epoch 036: 100%|██████████| 283/283 [00:13<00:00, 21.21it/s]


Epoch 36  average loss: 1.3876


Epoch 037: 100%|██████████| 283/283 [00:13<00:00, 21.27it/s]


Epoch 37  average loss: 1.4045


Epoch 038: 100%|██████████| 283/283 [00:13<00:00, 21.42it/s]


Epoch 38  average loss: 1.3717


Epoch 039: 100%|██████████| 283/283 [00:13<00:00, 20.81it/s]


Epoch 39  average loss: 1.3772


Epoch 040: 100%|██████████| 283/283 [00:13<00:00, 21.08it/s]


Epoch 40  average loss: 1.3588


Epoch 041: 100%|██████████| 283/283 [00:13<00:00, 21.17it/s]


Epoch 41  average loss: 1.3412


Epoch 042: 100%|██████████| 283/283 [00:13<00:00, 20.27it/s]


Epoch 42  average loss: 1.3355


Epoch 043: 100%|██████████| 283/283 [00:13<00:00, 20.70it/s]


Epoch 43  average loss: 1.3445


Epoch 044: 100%|██████████| 283/283 [00:13<00:00, 21.15it/s]


Epoch 44  average loss: 1.3218


Epoch 045: 100%|██████████| 283/283 [00:13<00:00, 21.50it/s]


Epoch 45  average loss: 1.3066


Epoch 046: 100%|██████████| 283/283 [00:13<00:00, 21.16it/s]


Epoch 46  average loss: 1.3120


Epoch 047: 100%|██████████| 283/283 [00:11<00:00, 24.28it/s]


Epoch 47  average loss: 1.3062


Epoch 048: 100%|██████████| 283/283 [00:13<00:00, 20.74it/s]


Epoch 48  average loss: 1.2874


Epoch 049: 100%|██████████| 283/283 [00:14<00:00, 19.95it/s]


Epoch 49  average loss: 1.3048


Epoch 050: 100%|██████████| 283/283 [00:11<00:00, 24.20it/s]


Epoch 50  average loss: 1.3010


Epoch 051: 100%|██████████| 283/283 [00:13<00:00, 21.53it/s]


Epoch 51  average loss: 1.3019


Epoch 052: 100%|██████████| 283/283 [00:12<00:00, 22.17it/s]


Epoch 52  average loss: 1.2581


Epoch 053: 100%|██████████| 283/283 [00:12<00:00, 23.44it/s]


Epoch 53  average loss: 1.2528


Epoch 054: 100%|██████████| 283/283 [00:13<00:00, 21.34it/s]


Epoch 54  average loss: 1.2515


Epoch 055: 100%|██████████| 283/283 [00:13<00:00, 21.28it/s]


Epoch 55  average loss: 1.2531


Epoch 056: 100%|██████████| 283/283 [00:13<00:00, 20.70it/s]


Epoch 56  average loss: 1.2294


Epoch 057: 100%|██████████| 283/283 [00:13<00:00, 20.78it/s]


Epoch 57  average loss: 1.2413


Epoch 058: 100%|██████████| 283/283 [00:12<00:00, 23.02it/s]


Epoch 58  average loss: 1.2248


Epoch 059: 100%|██████████| 283/283 [00:12<00:00, 23.54it/s]


Epoch 59  average loss: 1.2268


Epoch 060: 100%|██████████| 283/283 [00:12<00:00, 21.97it/s]


Epoch 60  average loss: 1.2117


Epoch 061: 100%|██████████| 283/283 [00:13<00:00, 20.83it/s]


Epoch 61  average loss: 1.2286


Epoch 062: 100%|██████████| 283/283 [00:13<00:00, 21.17it/s]


Epoch 62  average loss: 1.2090


Epoch 063: 100%|██████████| 283/283 [00:13<00:00, 20.34it/s]


Epoch 63  average loss: 1.1873


Epoch 064: 100%|██████████| 283/283 [00:11<00:00, 23.79it/s]


Epoch 64  average loss: 1.1837


Epoch 065: 100%|██████████| 283/283 [00:12<00:00, 22.05it/s]


Epoch 65  average loss: 1.1805


Epoch 066: 100%|██████████| 283/283 [00:11<00:00, 23.86it/s]


Epoch 66  average loss: 1.1721


Epoch 067: 100%|██████████| 283/283 [00:13<00:00, 20.35it/s]


Epoch 67  average loss: 1.1813


Epoch 068: 100%|██████████| 283/283 [00:13<00:00, 20.58it/s]


Epoch 68  average loss: 1.1431


Epoch 069: 100%|██████████| 283/283 [00:12<00:00, 22.93it/s]


Epoch 69  average loss: 1.1764


Epoch 070: 100%|██████████| 283/283 [00:11<00:00, 24.23it/s]


Epoch 70  average loss: 1.1642


Epoch 071: 100%|██████████| 283/283 [00:11<00:00, 23.98it/s]


Epoch 71  average loss: 1.1455


Epoch 072: 100%|██████████| 283/283 [00:13<00:00, 21.61it/s]


Epoch 72  average loss: 1.1374


Epoch 073: 100%|██████████| 283/283 [00:13<00:00, 20.47it/s]


Epoch 73  average loss: 1.1280


Epoch 074: 100%|██████████| 283/283 [00:12<00:00, 22.67it/s]


Epoch 74  average loss: 1.1211


Epoch 075: 100%|██████████| 283/283 [00:13<00:00, 21.15it/s]


Epoch 75  average loss: 1.1245


Epoch 076: 100%|██████████| 283/283 [00:11<00:00, 23.59it/s]


Epoch 76  average loss: 1.1072


Epoch 077: 100%|██████████| 283/283 [00:13<00:00, 21.43it/s]


Epoch 77  average loss: 1.0868


Epoch 078: 100%|██████████| 283/283 [00:14<00:00, 19.68it/s]


Epoch 78  average loss: 1.1153


Epoch 079: 100%|██████████| 283/283 [00:14<00:00, 20.10it/s]


Epoch 79  average loss: 1.0956


Epoch 080: 100%|██████████| 283/283 [00:13<00:00, 20.64it/s]


Epoch 80  average loss: 1.0942


Epoch 081: 100%|██████████| 283/283 [00:13<00:00, 20.86it/s]


Epoch 81  average loss: 1.0797


Epoch 082: 100%|██████████| 283/283 [00:13<00:00, 21.58it/s]


Epoch 82  average loss: 1.1157


Epoch 083: 100%|██████████| 283/283 [00:13<00:00, 20.88it/s]


Epoch 83  average loss: 1.0570


Epoch 084: 100%|██████████| 283/283 [00:13<00:00, 20.30it/s]


Epoch 84  average loss: 1.0554


Epoch 085: 100%|██████████| 283/283 [00:13<00:00, 21.28it/s]


Epoch 85  average loss: 1.0632


Epoch 086: 100%|██████████| 283/283 [00:12<00:00, 23.27it/s]


Epoch 86  average loss: 1.0777


Epoch 087: 100%|██████████| 283/283 [00:13<00:00, 20.71it/s]


Epoch 87  average loss: 1.0398


Epoch 088: 100%|██████████| 283/283 [00:13<00:00, 21.25it/s]


Epoch 88  average loss: 1.0572


Epoch 089: 100%|██████████| 283/283 [00:12<00:00, 22.66it/s]


Epoch 89  average loss: 1.0581


Epoch 090: 100%|██████████| 283/283 [00:13<00:00, 21.29it/s]


Epoch 90  average loss: 1.0396


Epoch 091: 100%|██████████| 283/283 [00:13<00:00, 21.48it/s]


Epoch 91  average loss: 1.0289


Epoch 092: 100%|██████████| 283/283 [00:14<00:00, 20.02it/s]


Epoch 92  average loss: 1.0593


Epoch 093: 100%|██████████| 283/283 [00:13<00:00, 21.64it/s]


Epoch 93  average loss: 1.0388


Epoch 094: 100%|██████████| 283/283 [00:12<00:00, 23.04it/s]


Epoch 94  average loss: 1.0267


Epoch 095: 100%|██████████| 283/283 [00:13<00:00, 21.12it/s]


Epoch 95  average loss: 1.0109


Epoch 096: 100%|██████████| 283/283 [00:13<00:00, 21.22it/s]


Epoch 96  average loss: 1.0388


Epoch 097: 100%|██████████| 283/283 [00:13<00:00, 20.25it/s]


Epoch 97  average loss: 1.0352


Epoch 098: 100%|██████████| 283/283 [00:13<00:00, 20.56it/s]


Epoch 98  average loss: 1.0043


Epoch 099: 100%|██████████| 283/283 [00:14<00:00, 19.79it/s]


Epoch 99  average loss: 0.9950


Epoch 100: 100%|██████████| 283/283 [00:13<00:00, 21.16it/s]


Epoch 100  average loss: 0.9963


Epoch 101: 100%|██████████| 283/283 [00:13<00:00, 20.46it/s]


Epoch 101  average loss: 1.0100


Epoch 102: 100%|██████████| 283/283 [00:13<00:00, 21.26it/s]


Epoch 102  average loss: 0.9803


Epoch 103: 100%|██████████| 283/283 [00:13<00:00, 21.07it/s]


Epoch 103  average loss: 1.0029


Epoch 104: 100%|██████████| 283/283 [00:13<00:00, 20.63it/s]


Epoch 104  average loss: 1.0466


Epoch 105: 100%|██████████| 283/283 [00:13<00:00, 21.07it/s]


Epoch 105  average loss: 0.9680


Epoch 106: 100%|██████████| 283/283 [00:13<00:00, 20.28it/s]


Epoch 106  average loss: 0.9560


Epoch 107: 100%|██████████| 283/283 [00:14<00:00, 19.87it/s]


Epoch 107  average loss: 0.9623


Epoch 108: 100%|██████████| 283/283 [00:13<00:00, 20.56it/s]


Epoch 108  average loss: 0.9949


Epoch 109: 100%|██████████| 283/283 [00:13<00:00, 20.46it/s]


Epoch 109  average loss: 0.9944


Epoch 110: 100%|██████████| 283/283 [00:13<00:00, 21.38it/s]


Epoch 110  average loss: 0.9570


Epoch 111: 100%|██████████| 283/283 [00:14<00:00, 19.87it/s]


Epoch 111  average loss: 0.9608


Epoch 112: 100%|██████████| 283/283 [00:14<00:00, 19.94it/s]


Epoch 112  average loss: 0.9463


Epoch 113: 100%|██████████| 283/283 [00:13<00:00, 20.69it/s]


Epoch 113  average loss: 0.9536


Epoch 114: 100%|██████████| 283/283 [00:13<00:00, 20.78it/s]


Epoch 114  average loss: 0.9700


Epoch 115: 100%|██████████| 283/283 [00:14<00:00, 20.10it/s]


Epoch 115  average loss: 0.9331


Epoch 116: 100%|██████████| 283/283 [00:14<00:00, 19.99it/s]


Epoch 116  average loss: 0.9369


Epoch 117: 100%|██████████| 283/283 [00:11<00:00, 24.08it/s]


Epoch 117  average loss: 0.9398


Epoch 118: 100%|██████████| 283/283 [00:12<00:00, 22.53it/s]


Epoch 118  average loss: 0.9340


Epoch 119: 100%|██████████| 283/283 [00:12<00:00, 21.86it/s]


Epoch 119  average loss: 0.9373


Epoch 120: 100%|██████████| 283/283 [00:13<00:00, 20.76it/s]


Epoch 120  average loss: 0.9485


Epoch 121: 100%|██████████| 283/283 [00:13<00:00, 20.49it/s]


Epoch 121  average loss: 0.9389


Epoch 122: 100%|██████████| 283/283 [00:14<00:00, 19.96it/s]


Epoch 122  average loss: 0.9069


Epoch 123: 100%|██████████| 283/283 [00:14<00:00, 19.82it/s]


Epoch 123  average loss: 0.9100


Epoch 124: 100%|██████████| 283/283 [00:12<00:00, 23.11it/s]


Epoch 124  average loss: 0.9264


Epoch 125: 100%|██████████| 283/283 [00:12<00:00, 21.84it/s]


Epoch 125  average loss: 0.9140


Epoch 126: 100%|██████████| 283/283 [00:13<00:00, 20.25it/s]


Epoch 126  average loss: 0.9015


Epoch 127: 100%|██████████| 283/283 [00:13<00:00, 20.41it/s]


Epoch 127  average loss: 0.9198


Epoch 128: 100%|██████████| 283/283 [00:14<00:00, 19.92it/s]


Epoch 128  average loss: 0.9052


Epoch 129: 100%|██████████| 283/283 [00:14<00:00, 19.68it/s]


Epoch 129  average loss: 0.8921


Epoch 130: 100%|██████████| 283/283 [00:13<00:00, 21.30it/s]


Epoch 130  average loss: 0.9065


Epoch 131: 100%|██████████| 283/283 [00:12<00:00, 23.40it/s]


Epoch 131  average loss: 0.8902


Epoch 132: 100%|██████████| 283/283 [00:13<00:00, 20.96it/s]


Epoch 132  average loss: 0.9085


Epoch 133: 100%|██████████| 283/283 [00:13<00:00, 21.43it/s]


Epoch 133  average loss: 0.8969


Epoch 134: 100%|██████████| 283/283 [00:13<00:00, 20.66it/s]


Epoch 134  average loss: 0.8670


Epoch 135: 100%|██████████| 283/283 [00:12<00:00, 22.14it/s]


Epoch 135  average loss: 0.8718


Epoch 136: 100%|██████████| 283/283 [00:13<00:00, 20.26it/s]


Epoch 136  average loss: 0.8627


Epoch 137: 100%|██████████| 283/283 [00:14<00:00, 20.05it/s]


Epoch 137  average loss: 0.8763


Epoch 138: 100%|██████████| 283/283 [00:14<00:00, 19.93it/s]


Epoch 138  average loss: 0.8999


Epoch 139: 100%|██████████| 283/283 [00:12<00:00, 22.00it/s]


Epoch 139  average loss: 0.8727


Epoch 140: 100%|██████████| 283/283 [00:12<00:00, 22.72it/s]


Epoch 140  average loss: 0.8612


Epoch 141: 100%|██████████| 283/283 [00:12<00:00, 23.00it/s]


Epoch 141  average loss: 0.8602


Epoch 142: 100%|██████████| 283/283 [00:13<00:00, 20.27it/s]


Epoch 142  average loss: 0.8514


Epoch 143: 100%|██████████| 283/283 [00:13<00:00, 21.36it/s]


Epoch 143  average loss: 0.8497


Epoch 144: 100%|██████████| 283/283 [00:12<00:00, 23.16it/s]


Epoch 144  average loss: 0.8502


Epoch 145: 100%|██████████| 283/283 [00:12<00:00, 21.91it/s]


Epoch 145  average loss: 0.8438


Epoch 146: 100%|██████████| 283/283 [00:11<00:00, 24.66it/s]


Epoch 146  average loss: 0.8660


Epoch 147: 100%|██████████| 283/283 [00:13<00:00, 21.10it/s]


Epoch 147  average loss: 0.8405


Epoch 148: 100%|██████████| 283/283 [00:12<00:00, 23.05it/s]


Epoch 148  average loss: 0.8471


Epoch 149: 100%|██████████| 283/283 [00:12<00:00, 22.38it/s]


Epoch 149  average loss: 0.8459


Epoch 150: 100%|██████████| 283/283 [00:12<00:00, 22.51it/s]


Epoch 150  average loss: 0.8076


Epoch 151: 100%|██████████| 283/283 [00:12<00:00, 22.37it/s]


Epoch 151  average loss: 0.8307


Epoch 152: 100%|██████████| 283/283 [00:13<00:00, 20.84it/s]


Epoch 152  average loss: 0.8761


Epoch 153: 100%|██████████| 283/283 [00:12<00:00, 21.94it/s]


Epoch 153  average loss: 0.8319


Epoch 154: 100%|██████████| 283/283 [00:12<00:00, 23.38it/s]


Epoch 154  average loss: 0.8134


Epoch 155: 100%|██████████| 283/283 [00:13<00:00, 20.82it/s]


Epoch 155  average loss: 0.8320


Epoch 156: 100%|██████████| 283/283 [00:13<00:00, 21.70it/s]


Epoch 156  average loss: 0.8517


Epoch 157: 100%|██████████| 283/283 [00:13<00:00, 20.61it/s]


Epoch 157  average loss: 0.8088


Epoch 158: 100%|██████████| 283/283 [00:13<00:00, 20.95it/s]


Epoch 158  average loss: 0.8090


Epoch 159: 100%|██████████| 283/283 [00:12<00:00, 23.55it/s]


Epoch 159  average loss: 0.8003


Epoch 160: 100%|██████████| 283/283 [00:12<00:00, 21.82it/s]


Epoch 160  average loss: 0.8117


Epoch 161: 100%|██████████| 283/283 [00:11<00:00, 24.85it/s]


Epoch 161  average loss: 0.8048


Epoch 162: 100%|██████████| 283/283 [00:11<00:00, 24.90it/s]


Epoch 162  average loss: 0.7963


Epoch 163: 100%|██████████| 283/283 [00:11<00:00, 24.81it/s]


Epoch 163  average loss: 0.8366


Epoch 164: 100%|██████████| 283/283 [00:12<00:00, 21.81it/s]


Epoch 164  average loss: 0.8337


Epoch 165: 100%|██████████| 283/283 [00:12<00:00, 22.26it/s]


Epoch 165  average loss: 0.7954


Epoch 166: 100%|██████████| 283/283 [00:13<00:00, 20.81it/s]


Epoch 166  average loss: 0.7873


Epoch 167: 100%|██████████| 283/283 [00:13<00:00, 21.14it/s]


Epoch 167  average loss: 0.8252


Epoch 168: 100%|██████████| 283/283 [00:13<00:00, 21.36it/s]


Epoch 168  average loss: 0.8315


Epoch 169: 100%|██████████| 283/283 [00:13<00:00, 20.25it/s]


Epoch 169  average loss: 0.7906


Epoch 170: 100%|██████████| 283/283 [00:13<00:00, 21.67it/s]


Epoch 170  average loss: 0.7759


Epoch 171: 100%|██████████| 283/283 [00:13<00:00, 20.77it/s]


Epoch 171  average loss: 0.7842


Epoch 172: 100%|██████████| 283/283 [00:12<00:00, 23.39it/s]


Epoch 172  average loss: 0.7987


Epoch 173: 100%|██████████| 283/283 [00:12<00:00, 22.57it/s]


Epoch 173  average loss: 0.7932


Epoch 174: 100%|██████████| 283/283 [00:12<00:00, 22.14it/s]


Epoch 174  average loss: 0.7893


Epoch 175: 100%|██████████| 283/283 [00:13<00:00, 20.44it/s]


Epoch 175  average loss: 0.7791


Epoch 176: 100%|██████████| 283/283 [00:13<00:00, 21.25it/s]


Epoch 176  average loss: 0.7761


Epoch 177: 100%|██████████| 283/283 [00:13<00:00, 21.54it/s]


Epoch 177  average loss: 0.7810


Epoch 178: 100%|██████████| 283/283 [00:11<00:00, 24.88it/s]


Epoch 178  average loss: 0.7931


Epoch 179: 100%|██████████| 283/283 [00:13<00:00, 21.17it/s]


Epoch 179  average loss: 0.7695


Epoch 180: 100%|██████████| 283/283 [00:13<00:00, 21.59it/s]


Epoch 180  average loss: 0.7469


Epoch 181: 100%|██████████| 283/283 [00:13<00:00, 21.37it/s]


Epoch 181  average loss: 0.7514


Epoch 182: 100%|██████████| 283/283 [00:13<00:00, 21.19it/s]


Epoch 182  average loss: 0.7756


Epoch 183: 100%|██████████| 283/283 [00:12<00:00, 23.22it/s]


Epoch 183  average loss: 0.7868


Epoch 184: 100%|██████████| 283/283 [00:11<00:00, 23.89it/s]


Epoch 184  average loss: 0.7797


Epoch 185: 100%|██████████| 283/283 [00:13<00:00, 21.42it/s]


Epoch 185  average loss: 0.7727


Epoch 186: 100%|██████████| 283/283 [00:13<00:00, 20.91it/s]


Epoch 186  average loss: 0.7976


Epoch 187: 100%|██████████| 283/283 [00:13<00:00, 21.75it/s]


Epoch 187  average loss: 0.7441


Epoch 188: 100%|██████████| 283/283 [00:12<00:00, 22.29it/s]


Epoch 188  average loss: 0.7435


Epoch 189: 100%|██████████| 283/283 [00:11<00:00, 24.27it/s]


Epoch 189  average loss: 0.7442


Epoch 190: 100%|██████████| 283/283 [00:13<00:00, 20.67it/s]


Epoch 190  average loss: 0.7426


Epoch 191: 100%|██████████| 283/283 [00:12<00:00, 21.91it/s]


Epoch 191  average loss: 0.8086


Epoch 192: 100%|██████████| 283/283 [00:13<00:00, 20.69it/s]


Epoch 192  average loss: 0.7635


Epoch 193: 100%|██████████| 283/283 [00:13<00:00, 21.69it/s]


Epoch 193  average loss: 0.7433


Epoch 194: 100%|██████████| 283/283 [00:12<00:00, 22.18it/s]


Epoch 194  average loss: 0.7432


Epoch 195: 100%|██████████| 283/283 [00:13<00:00, 20.49it/s]


Epoch 195  average loss: 0.7423


Epoch 196: 100%|██████████| 283/283 [00:12<00:00, 22.09it/s]


Epoch 196  average loss: 0.7405


Epoch 197: 100%|██████████| 283/283 [00:13<00:00, 21.17it/s]


Epoch 197  average loss: 0.7633


Epoch 198: 100%|██████████| 283/283 [00:14<00:00, 20.07it/s]


Epoch 198  average loss: 0.7286


Epoch 199: 100%|██████████| 283/283 [00:13<00:00, 21.18it/s]


Epoch 199  average loss: 0.7273


Epoch 200: 100%|██████████| 283/283 [00:12<00:00, 22.45it/s]


Epoch 200  average loss: 0.7351


Epoch 201: 100%|██████████| 283/283 [00:12<00:00, 23.14it/s]


Epoch 201  average loss: 0.7265


Epoch 202: 100%|██████████| 283/283 [00:13<00:00, 21.26it/s]


Epoch 202  average loss: 0.7606


Epoch 203: 100%|██████████| 283/283 [00:12<00:00, 23.01it/s]


Epoch 203  average loss: 0.7423


Epoch 204: 100%|██████████| 283/283 [00:13<00:00, 21.49it/s]


Epoch 204  average loss: 0.7403


Epoch 205: 100%|██████████| 283/283 [00:13<00:00, 20.25it/s]


Epoch 205  average loss: 0.7521


Epoch 206: 100%|██████████| 283/283 [00:12<00:00, 22.44it/s]


Epoch 206  average loss: 0.7244


Epoch 207: 100%|██████████| 283/283 [00:11<00:00, 24.18it/s]


Epoch 207  average loss: 0.7376


Epoch 208: 100%|██████████| 283/283 [00:12<00:00, 22.41it/s]


Epoch 208  average loss: 0.7388


Epoch 209: 100%|██████████| 283/283 [00:13<00:00, 20.40it/s]


Epoch 209  average loss: 0.7068


Epoch 210: 100%|██████████| 283/283 [00:13<00:00, 21.20it/s]


Epoch 210  average loss: 0.7122


Epoch 211: 100%|██████████| 283/283 [00:13<00:00, 20.48it/s]


Epoch 211  average loss: 0.7234


Epoch 212: 100%|██████████| 283/283 [00:14<00:00, 20.07it/s]


Epoch 212  average loss: 0.7069


Epoch 213: 100%|██████████| 283/283 [00:13<00:00, 21.26it/s]


Epoch 213  average loss: 0.7437


Epoch 214: 100%|██████████| 283/283 [00:12<00:00, 22.34it/s]


Epoch 214  average loss: 0.7246


Epoch 215: 100%|██████████| 283/283 [00:12<00:00, 22.63it/s]


Epoch 215  average loss: 0.7165


Epoch 216: 100%|██████████| 283/283 [00:13<00:00, 20.30it/s]


Epoch 216  average loss: 0.7612


Epoch 217: 100%|██████████| 283/283 [00:13<00:00, 20.30it/s]


Epoch 217  average loss: 0.7121


Epoch 218: 100%|██████████| 283/283 [00:12<00:00, 23.27it/s]


Epoch 218  average loss: 0.7173


Epoch 219: 100%|██████████| 283/283 [00:11<00:00, 24.47it/s]


Epoch 219  average loss: 0.7317


Epoch 220: 100%|██████████| 283/283 [00:11<00:00, 23.68it/s]


Epoch 220  average loss: 0.7084


Epoch 221: 100%|██████████| 283/283 [00:12<00:00, 23.14it/s]


Epoch 221  average loss: 0.7014


Epoch 222: 100%|██████████| 283/283 [00:13<00:00, 20.42it/s]


Epoch 222  average loss: 0.7298


Epoch 223: 100%|██████████| 283/283 [00:12<00:00, 21.95it/s]


Epoch 223  average loss: 0.6947


Epoch 224: 100%|██████████| 283/283 [00:13<00:00, 20.89it/s]


Epoch 224  average loss: 0.7542


Epoch 225: 100%|██████████| 283/283 [00:13<00:00, 20.90it/s]


Epoch 225  average loss: 0.6925


Epoch 226: 100%|██████████| 283/283 [00:12<00:00, 23.06it/s]


Epoch 226  average loss: 0.6949


Epoch 227: 100%|██████████| 283/283 [00:13<00:00, 20.55it/s]


Epoch 227  average loss: 0.7035


Epoch 228: 100%|██████████| 283/283 [00:14<00:00, 20.18it/s]


Epoch 228  average loss: 0.6920


Epoch 229: 100%|██████████| 283/283 [00:12<00:00, 23.07it/s]


Epoch 229  average loss: 0.7045


Epoch 230: 100%|██████████| 283/283 [00:12<00:00, 22.22it/s]


Epoch 230  average loss: 0.7242


Epoch 231: 100%|██████████| 283/283 [00:11<00:00, 24.42it/s]


Epoch 231  average loss: 0.7086


Epoch 232: 100%|██████████| 283/283 [00:12<00:00, 22.62it/s]


Epoch 232  average loss: 0.6961


Epoch 233: 100%|██████████| 283/283 [00:13<00:00, 21.25it/s]


Epoch 233  average loss: 0.7023


Epoch 234: 100%|██████████| 283/283 [00:12<00:00, 23.22it/s]


Epoch 234  average loss: 0.6994


Epoch 235: 100%|██████████| 283/283 [00:13<00:00, 21.50it/s]


Epoch 235  average loss: 0.6678


Epoch 236: 100%|██████████| 283/283 [00:12<00:00, 21.93it/s]


Epoch 236  average loss: 0.6871


Epoch 237: 100%|██████████| 283/283 [00:13<00:00, 21.26it/s]


Epoch 237  average loss: 0.6983


Epoch 238: 100%|██████████| 283/283 [00:13<00:00, 20.90it/s]


Epoch 238  average loss: 0.6938


Epoch 239: 100%|██████████| 283/283 [00:13<00:00, 20.72it/s]


Epoch 239  average loss: 0.7522


Epoch 240: 100%|██████████| 283/283 [00:13<00:00, 20.42it/s]


Epoch 240  average loss: 0.6824


Epoch 241: 100%|██████████| 283/283 [00:13<00:00, 20.35it/s]


Epoch 241  average loss: 0.7016


Epoch 242: 100%|██████████| 283/283 [00:12<00:00, 22.20it/s]


Epoch 242  average loss: 0.6815


Epoch 243: 100%|██████████| 283/283 [00:11<00:00, 24.21it/s]


Epoch 243  average loss: 0.6836


Epoch 244: 100%|██████████| 283/283 [00:12<00:00, 23.27it/s]


Epoch 244  average loss: 0.6697


Epoch 245: 100%|██████████| 283/283 [00:13<00:00, 21.70it/s]


Epoch 245  average loss: 0.6696


Epoch 246: 100%|██████████| 283/283 [00:12<00:00, 22.82it/s]


Epoch 246  average loss: 0.7055


Epoch 247: 100%|██████████| 283/283 [00:13<00:00, 20.46it/s]


Epoch 247  average loss: 0.6847


Epoch 248: 100%|██████████| 283/283 [00:13<00:00, 20.60it/s]


Epoch 248  average loss: 0.6898


Epoch 249: 100%|██████████| 283/283 [00:13<00:00, 21.54it/s]


Epoch 249  average loss: 0.7260


Epoch 250: 100%|██████████| 283/283 [00:13<00:00, 21.27it/s]


Epoch 250  average loss: 0.6687


Epoch 251: 100%|██████████| 283/283 [00:13<00:00, 21.44it/s]


Epoch 251  average loss: 0.6615


Epoch 252: 100%|██████████| 283/283 [00:14<00:00, 20.18it/s]


Epoch 252  average loss: 0.6575


Epoch 253: 100%|██████████| 283/283 [00:13<00:00, 20.97it/s]


Epoch 253  average loss: 0.6623


Epoch 254: 100%|██████████| 283/283 [00:13<00:00, 21.34it/s]


Epoch 254  average loss: 0.6976


Epoch 255: 100%|██████████| 283/283 [00:13<00:00, 21.20it/s]


Epoch 255  average loss: 0.6982


Epoch 256: 100%|██████████| 283/283 [00:12<00:00, 22.35it/s]


Epoch 256  average loss: 0.6572


Epoch 257: 100%|██████████| 283/283 [00:13<00:00, 21.74it/s]


Epoch 257  average loss: 0.6567


Epoch 258: 100%|██████████| 283/283 [00:12<00:00, 22.65it/s]


Epoch 258  average loss: 0.6565


Epoch 259: 100%|██████████| 283/283 [00:12<00:00, 21.80it/s]


Epoch 259  average loss: 0.6808


Epoch 260: 100%|██████████| 283/283 [00:13<00:00, 21.69it/s]


Epoch 260  average loss: 0.6767


Epoch 261: 100%|██████████| 283/283 [00:12<00:00, 22.48it/s]


Epoch 261  average loss: 0.6882


Epoch 262: 100%|██████████| 283/283 [00:11<00:00, 23.72it/s]


Epoch 262  average loss: 0.6739


Epoch 263: 100%|██████████| 283/283 [00:12<00:00, 21.94it/s]


Epoch 263  average loss: 0.6716


Epoch 264: 100%|██████████| 283/283 [00:13<00:00, 21.06it/s]


Epoch 264  average loss: 0.6535


Epoch 265: 100%|██████████| 283/283 [00:12<00:00, 23.16it/s]


Epoch 265  average loss: 0.6493


Epoch 266: 100%|██████████| 283/283 [00:13<00:00, 20.38it/s]


Epoch 266  average loss: 0.6716


Epoch 267: 100%|██████████| 283/283 [00:12<00:00, 23.25it/s]


Epoch 267  average loss: 0.6929


Epoch 268: 100%|██████████| 283/283 [00:12<00:00, 22.41it/s]


Epoch 268  average loss: 0.6813


Epoch 269: 100%|██████████| 283/283 [00:12<00:00, 22.95it/s]


Epoch 269  average loss: 0.6690


Epoch 270: 100%|██████████| 283/283 [00:13<00:00, 21.05it/s]


Epoch 270  average loss: 0.6725


Epoch 271: 100%|██████████| 283/283 [00:11<00:00, 24.51it/s]


Epoch 271  average loss: 0.7006


Epoch 272: 100%|██████████| 283/283 [00:13<00:00, 21.57it/s]


Epoch 272  average loss: 0.6452


Epoch 273: 100%|██████████| 283/283 [00:13<00:00, 20.28it/s]


Epoch 273  average loss: 0.6383


Epoch 274: 100%|██████████| 283/283 [00:12<00:00, 22.48it/s]


Epoch 274  average loss: 0.6596


Epoch 275: 100%|██████████| 283/283 [00:13<00:00, 21.09it/s]


Epoch 275  average loss: 0.6619


Epoch 276: 100%|██████████| 283/283 [00:12<00:00, 21.98it/s]


Epoch 276  average loss: 0.6569


Epoch 277: 100%|██████████| 283/283 [00:13<00:00, 21.59it/s]


Epoch 277  average loss: 0.6478


Epoch 278: 100%|██████████| 283/283 [00:13<00:00, 20.55it/s]


Epoch 278  average loss: 0.6545


Epoch 279: 100%|██████████| 283/283 [00:12<00:00, 22.65it/s]


Epoch 279  average loss: 0.6432


Epoch 280: 100%|██████████| 283/283 [00:13<00:00, 21.38it/s]


Epoch 280  average loss: 0.6497


Epoch 281: 100%|██████████| 283/283 [00:13<00:00, 20.51it/s]


Epoch 281  average loss: 0.6480


Epoch 282: 100%|██████████| 283/283 [00:12<00:00, 22.55it/s]


Epoch 282  average loss: 0.6484


Epoch 283: 100%|██████████| 283/283 [00:12<00:00, 23.00it/s]


Epoch 283  average loss: 0.6443


Epoch 284: 100%|██████████| 283/283 [00:13<00:00, 21.57it/s]


Epoch 284  average loss: 0.6375


Epoch 285: 100%|██████████| 283/283 [00:11<00:00, 24.53it/s]


Epoch 285  average loss: 0.6584


Epoch 286: 100%|██████████| 283/283 [00:12<00:00, 22.80it/s]


Epoch 286  average loss: 0.7418


Epoch 287: 100%|██████████| 283/283 [00:13<00:00, 20.24it/s]


Epoch 287  average loss: 0.6496


Epoch 288: 100%|██████████| 283/283 [00:12<00:00, 21.90it/s]


Epoch 288  average loss: 0.6315


Epoch 289: 100%|██████████| 283/283 [00:11<00:00, 24.63it/s]


Epoch 289  average loss: 0.6356


Epoch 290: 100%|██████████| 283/283 [00:12<00:00, 23.38it/s]


Epoch 290  average loss: 0.6316


Epoch 291: 100%|██████████| 283/283 [00:12<00:00, 21.82it/s]


Epoch 291  average loss: 0.6253


Epoch 292: 100%|██████████| 283/283 [00:12<00:00, 22.35it/s]


Epoch 292  average loss: 0.6235


Epoch 293: 100%|██████████| 283/283 [00:13<00:00, 20.99it/s]


Epoch 293  average loss: 0.6396


Epoch 294: 100%|██████████| 283/283 [00:12<00:00, 21.88it/s]


Epoch 294  average loss: 0.6601


Epoch 295: 100%|██████████| 283/283 [00:13<00:00, 21.09it/s]


Epoch 295  average loss: 0.7085


Epoch 296: 100%|██████████| 283/283 [00:13<00:00, 20.63it/s]


Epoch 296  average loss: 0.7117


Epoch 297: 100%|██████████| 283/283 [00:12<00:00, 22.25it/s]


Epoch 297  average loss: 0.6476


Epoch 298: 100%|██████████| 283/283 [00:12<00:00, 22.01it/s]


Epoch 298  average loss: 0.6261


Epoch 299: 100%|██████████| 283/283 [00:12<00:00, 21.93it/s]


Epoch 299  average loss: 0.6417


Epoch 300: 100%|██████████| 283/283 [00:12<00:00, 23.56it/s]

Epoch 300  average loss: 0.6307





In [None]:
# contrastive loss spconv
# Train SPConvModel with cross-entropy plus an edge-wise contrastive term
# weighted by config_lambda_c (0.01 in this run); the trained model is saved
# to config_save_path.
import os

config_max_epochs = 200
config_device = 'cuda'
config_margin = 1.0
config_lambda_c = 0.01
config_seed = 0
config_save_path = './experiments/spconv_contrastive_L0-01.pth'

# Seed torch so model initialisation and DataLoader shuffling are repeatable,
# which keeps runs with different config_lambda_c comparable.
torch.manual_seed(config_seed)

# batch_size=1 + collate returning the single sample: each step sees one tile dict.
train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])

model = SPConvModel(
    in_dim = 3+12,
    num_classes=89,
    pn_hidden=64,
    super_hidden=64,
    pn_dropout=0.1,
    super_dropout=0.2,
    edge_dim=32,
    super_num_conv=3
).to(device=config_device)

optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(1, config_max_epochs + 1):
    model.train()
    total_loss = 0.
    for d in tqdm(train_dl, desc=f"Epoch {epoch:03d}"):
        pos         = d['pos'].to(device=config_device)
        point_feats = d['point_feats'].to(device=config_device)
        labels      = d['labels'].to(device=config_device)
        super_idx   = d['superpoint_idx'].to(device=config_device)
        edge_idx    = d['edge_index'].to(device=config_device)

        logits, z = model(pos, point_feats, super_idx, edge_idx)
        L_class = F.cross_entropy(logits, labels)

        # Contrastive term over graph edges: pull same-label endpoints together
        # (squared distance) and push different-label endpoints apart until
        # their squared distance reaches config_margin.
        # NOTE(review): `labels` is indexed with edge endpoints, which assumes
        # edge_index, z and labels share one node index space — confirm.
        src, dst = edge_idx
        if src.numel() > 0:
            dist2 = (z[src] - z[dst]).pow(2).sum(dim=1)
            same = (labels[src] == labels[dst]).float()  # [E]
            loss_pos = same * dist2                                  # pull
            loss_neg = (1 - same) * F.relu(config_margin - dist2)    # push
            L_contrast = (loss_pos + loss_neg).mean()
        else:
            # .mean() over zero edges is NaN and would corrupt the weights;
            # a sample without edges simply contributes no contrastive signal.
            L_contrast = z.new_zeros(())

        # total loss
        loss = L_class + config_lambda_c * L_contrast
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item()

    print(f"Epoch {epoch:03d}  average loss: {total_loss/len(train_dl):.4f}")

# Make sure the output directory exists so a long run is not lost at save time.
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
torch.save(model, config_save_path)

In [None]:
# contrastive loss spconv
# Train SPConvModel with cross-entropy plus an edge-wise contrastive term
# weighted by config_lambda_c (1.0 in this run); the trained model is saved
# to config_save_path.
import os

config_max_epochs = 200
config_device = 'cuda'
config_margin = 1.0
config_lambda_c = 1.0
config_seed = 0
config_save_path = './experiments/spconv_contrastive_L1-00.pth'

# Seed torch so model initialisation and DataLoader shuffling are repeatable,
# which keeps runs with different config_lambda_c comparable.
torch.manual_seed(config_seed)

# batch_size=1 + collate returning the single sample: each step sees one tile dict.
train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])

model = SPConvModel(
    in_dim = 3+12,
    num_classes=89,
    pn_hidden=64,
    super_hidden=64,
    pn_dropout=0.1,
    super_dropout=0.2,
    edge_dim=32,
    super_num_conv=3
).to(device=config_device)

optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(1, config_max_epochs + 1):
    model.train()
    total_loss = 0.
    for d in tqdm(train_dl, desc=f"Epoch {epoch:03d}"):
        pos         = d['pos'].to(device=config_device)
        point_feats = d['point_feats'].to(device=config_device)
        labels      = d['labels'].to(device=config_device)
        super_idx   = d['superpoint_idx'].to(device=config_device)
        edge_idx    = d['edge_index'].to(device=config_device)

        logits, z = model(pos, point_feats, super_idx, edge_idx)
        L_class = F.cross_entropy(logits, labels)

        # Contrastive term over graph edges: pull same-label endpoints together
        # (squared distance) and push different-label endpoints apart until
        # their squared distance reaches config_margin.
        # NOTE(review): `labels` is indexed with edge endpoints, which assumes
        # edge_index, z and labels share one node index space — confirm.
        src, dst = edge_idx
        if src.numel() > 0:
            dist2 = (z[src] - z[dst]).pow(2).sum(dim=1)
            same = (labels[src] == labels[dst]).float()  # [E]
            loss_pos = same * dist2                                  # pull
            loss_neg = (1 - same) * F.relu(config_margin - dist2)    # push
            L_contrast = (loss_pos + loss_neg).mean()
        else:
            # .mean() over zero edges is NaN and would corrupt the weights;
            # a sample without edges simply contributes no contrastive signal.
            L_contrast = z.new_zeros(())

        # total loss
        loss = L_class + config_lambda_c * L_contrast
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item()

    print(f"Epoch {epoch:03d}  average loss: {total_loss/len(train_dl):.4f}")

# Make sure the output directory exists so a long run is not lost at save time.
os.makedirs(os.path.dirname(config_save_path), exist_ok=True)
torch.save(model, config_save_path)