In [1]:
import os
import random
import numpy as np
import pandas as pd
import os, json, joblib, re
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR

import polars as pl
from pathlib import Path
import warnings 
warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight

import gc  # garbage collection
import psutil
from scipy.spatial.transform import Rotation as R

In [2]:
# (Competition metric will only be imported when TRAINing)
TRAIN = True                     # ← set to True when you want to train

class config:
    AMP = False
    BATCH_SIZE_TRAIN = 8 #32
    BATCH_SIZE_VALID = 8 #32
    DEBUG = False
    EPOCHS = 4  #30
    FOLDS = 5
    GRADIENT_ACCUMULATION_STEPS = 1
    LEARNING_RATE = 1e-3
    MAX_GRAD_NORM = 1e7
    NUM_WORKERS = 0 # multiprocessing.cpu_count()
    PRINT_FREQ = 20
    SEED = 20
    WEIGHT_DECAY = 0.01
    PAD_PERCENTILE = 95
    SEQUENCE_LENGTH = 150

class paths:
    BASE_DIR = Path("C:/Users/konno/SynologyDrive/datasciense/projects_foler/1_kaggle/CMI/cmi-detect-behavior-with-sensor-data")
    
    OUTPUT_DIR = BASE_DIR / "output-02-wavenet"
    TEST_CSV = BASE_DIR / "test.csv"
    TEST_DEMOGRAPHICS = BASE_DIR / "test_demographics.csv"
    TRAIN_CSV = BASE_DIR / "train.csv"
    TRAIN_DEMOGRAPHICS = BASE_DIR / "train_demographics.csv"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("▶ imports ready · torch", torch.__version__, "device :", device)

▶ imports ready · torch 2.7.1+cpu device : cpu


In [3]:
def seed_everything(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [47]:

class MotionDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, alpha=0.2):
        """
        X: np.array or torch.Tensor of shape (N, )
        y: np.array or torch.Tensor if shaoe (N, )
        alpha: Beta distribution parameter for mixup
        """
        self.X = torch.tensor(X, dtype=torch.float32) if isinstance(X, np.ndarray) else X
        self.y = torch.tensor(y, dtype=torch.float32) if isinstance(y, np.ndarray) else y
        self.alpha = alpha

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x1, y1 = self.X[idx], self.y[idx]
        
        # Create shuffle tensor
        shuffle_index = np.random.randint(0, len(self.X))
        x2, y2 = self.X[shuffle_index], self.y[shuffle_index]       

        # Mix
        weight = np.random.beta(self.alpha, self.alpha)
        x_mix = x1 * weight + x2 * (1 - weight)
        y_mix = y1 * weight + y2 * (1 - weight)

        return x_mix, y_mix
    
# train_dataset = MixupDataset(config, df_train, X_tr, y_tr, y_soft_tr)
# train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE_TRAIN, shuffle=True)
# val_dataset = CustomDataset(config, df_train, X_val, y_val, y_soft_val)
# val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE_VALID, shuffle=True)

def pad_or_truncate(seq, max_len, mode=TRAIN, pad_value=0.0, dtype=np.float32) -> np.ndarray:
    """
    Pads or truncates a sequence to a fixed length.

    Parameters:
    - seq: np.ndarray of shape (L, D)
    - max_len: int, desired sequence length
    - mode: bool, True = random pad, False = regular pad
    - pad_value: float or int, value to use for padding
    - dtype: np.dtype, dtype for the output array

    Returns:
    - np.ndarray of shape (max_len, D)
    """
    # print("sequence shape", seq.shape)
    L, D = seq.shape
    # print("mode = ", mode)

    if L > max_len:
        return seq[:max_len] # truncate if too long

    elif L < max_len:
        total_padding = max_len - L
        
        if mode:
            pad_start = np.random.randint(0, total_padding + 1)
            pad_end = total_padding - pad_start
            
        else:
            pad_start = 0
            pad_end = total_padding

        start_padding = np.full((pad_start, D), pad_value, dtype=dtype)
        end_padding = np.full((pad_end, D), pad_value, dtype=dtype)
        padded = np.vstack((start_padding, seq, end_padding))
        # print("padded shape", padded.shape)
        return padded

    else:
        return seq.astype(dtype)

In [5]:
def remove_gravity_from_acc(acc_data, rot_data):

    if isinstance(acc_data, pd.DataFrame):
        acc_values = acc_data[['acc_x', 'acc_y', 'acc_z']].values
    else:
        acc_values = acc_data

    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data

    num_samples = acc_values.shape[0]
    linear_accel = np.zeros_like(acc_values)
    
    gravity_world = np.array([0, 0, 9.81])

    for i in range(num_samples):
        if np.all(np.isnan(quat_values[i])) or np.all(np.isclose(quat_values[i], 0)):
            linear_accel[i, :] = acc_values[i, :] 
            continue

        try:
            rotation = R.from_quat(quat_values[i])
            gravity_sensor_frame = rotation.apply(gravity_world, inverse=True)
            linear_accel[i, :] = acc_values[i, :] - gravity_sensor_frame
        except ValueError:
             linear_accel[i, :] = acc_values[i, :]
             
    return linear_accel

def calculate_angular_velocity_from_quat(rot_data, time_delta=1/200): # Assuming 200Hz sampling rate
    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data

    num_samples = quat_values.shape[0]
    angular_vel = np.zeros((num_samples, 3))

    for i in range(num_samples - 1):
        q_t = quat_values[i]
        q_t_plus_dt = quat_values[i+1]

        if np.all(np.isnan(q_t)) or np.all(np.isclose(q_t, 0)) or \
           np.all(np.isnan(q_t_plus_dt)) or np.all(np.isclose(q_t_plus_dt, 0)):
            continue

        try:
            rot_t = R.from_quat(q_t)
            rot_t_plus_dt = R.from_quat(q_t_plus_dt)

            # Calculate the relative rotation
            delta_rot = rot_t.inv() * rot_t_plus_dt
            
            # Convert delta rotation to angular velocity vector
            # The rotation vector (Euler axis * angle) scaled by 1/dt
            # is a good approximation for small delta_rot
            angular_vel[i, :] = delta_rot.as_rotvec() / time_delta
        except ValueError:
            # If quaternion is invalid, angular velocity remains zero
            pass
            
    return angular_vel

def calculate_angular_distance(rot_data):
    if isinstance(rot_data, pd.DataFrame):
        quat_values = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    else:
        quat_values = rot_data

    num_samples = quat_values.shape[0]
    angular_dist = np.zeros(num_samples)

    for i in range(num_samples - 1):
        q1 = quat_values[i]
        q2 = quat_values[i+1]

        if np.all(np.isnan(q1)) or np.all(np.isclose(q1, 0)) or \
           np.all(np.isnan(q2)) or np.all(np.isclose(q2, 0)):
            angular_dist[i] = 0 # Или np.nan, в зависимости от желаемого поведения
            continue
        try:
            # Преобразование кватернионов в объекты Rotation
            r1 = R.from_quat(q1)
            r2 = R.from_quat(q2)

            # Вычисление углового расстояния: 2 * arccos(|real(p * q*)|)
            # где p* - сопряженный кватернион q
            # В scipy.spatial.transform.Rotation, r1.inv() * r2 дает относительное вращение.
            # Угол этого относительного вращения - это и есть угловое расстояние.
            relative_rotation = r1.inv() * r2
            
            # Угол rotation vector соответствует угловому расстоянию
            # Норма rotation vector - это угол в радианах
            angle = np.linalg.norm(relative_rotation.as_rotvec())
            angular_dist[i] = angle
        except ValueError:
            angular_dist[i] = 0 # В случае недействительных кватернионов
            pass
            
    return angular_dist

def print_memory():
    process = psutil.Process()
    print(f"Memory Usage: {process.memory_info().rss / 1024**2:.2f} MB")

def parse_tof_column(col):
    # Match patterns like 'tof_1_v42' or 'tof_1_v42_norm'
    match = re.match(r"tof_(\d+)_v(\d+)", col)
    if match:
        sensor_num = int(match.group(1))
        pixel_num = int(match.group(2))
        return (sensor_num, pixel_num)
    else:
        return (float('inf'), float('inf'))  # put unmatchable columns at the end

In [6]:

class SEBlock1D(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x):
        # x shape: (batch, channels, time)
        b, c, _ = x.size()
        y = self.global_avg_pool(x).view(b, c)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(b, c, 1)
        return x * y.expand_as(x)

class ResidualSEBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, pool_size=2, dropout=0.3):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size // 2, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.se = SEBlock1D(out_channels)
        
        self.match_channels = None
        if in_channels != out_channels:
            self.match_channels = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),
                nn.BatchNorm1d(out_channels)
            )
        
        self.pool = nn.MaxPool1d(pool_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.se(out)

        if self.match_channels is not None:
            identity = self.match_channels(identity)
        
        out = F.relu(out + identity)
        out = self.pool(out)
        out = self.dropout(out)
        return out

class AttentionLayer(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.attn_fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        # x shape: (batch, time, features)
        # Compute scores with tanh activation
        scores = torch.tanh(self.attn_fc(x))  # (batch, time, 1)
        scores = scores.squeeze(-1)           # (batch, time)

        # Softmax over time dimension to get weights
        weights = F.softmax(scores, dim=1)    # (batch, time)
        weights = weights.unsqueeze(-1)       # (batch, time, 1)

        # Weighted sum of input features over time
        context = (x * weights).sum(dim=1)    # (batch, features)
        return context

class TwoBranchGestureModel(nn.Module):
    def __init__(self, imu_dim, tof_dim, n_classes, wd=1e-4):
        super().__init__()
        self.imu_dim = imu_dim
        self.tof_dim = tof_dim
        
        # IMU deep branch
        self.imu_branch = nn.Sequential(
            ResidualSEBlock1D(imu_dim, 64, kernel_size=3, dropout=0.1),
            ResidualSEBlock1D(64, 128, kernel_size=5, dropout=0.1)
        )

        # TOF Lighter branch
        self.tof_branch = nn.Sequential(
            nn.Conv1d(tof_dim, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(0.2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(128),
            nn.MaxPool1d(2),
            nn.Dropout(0.2),
        )

        self.lstm = nn.LSTM(256, 128, batch_first=True, bidirectional=True)
        self.gru = nn.GRU(256, 128, batch_first=True, bidirectional=True)

        # Gaussian noise (manual) and projection
        self.projection = nn.Sequential(
            nn.Dropout(0.09),
            nn.Linear(256, 16),
            nn.ELU()
        )

        self.pre_attn_dropout = nn.Dropout(0.4)
        self.attn = AttentionLayer(128*2 + 128*2 + 16)

        # Dense layer
        self.mlp = nn.Sequential(
            nn.Linear(528, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0,5),
            nn.Linear(256, 128, bias=False),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, n_classes),  # Softmax handled by loss (e.g., CrossEntropyLoss)
        )
    
    def forward(self, x):
        imu = x[:, :, :self.imu_dim].permute(0, 2, 1) #(B, imu_dim, T)
        tof = x[: ,:, self.imu_dim:].permute(0, 2, 1) #(B, tof_dim, T)

        imu_feat = self.imu_branch(imu)
        tof_feat = self.tof_branch(tof)

        imu_feat = imu_feat.permute(0, 2, 1)
        tof_feat = tof_feat.permute(0, 2, 1)

        merged = torch.cat([imu_feat, tof_feat], dim=-1)  #(B, T', 256)

        xa, _ = self.lstm(merged)
        xb, _ = self.gru(merged)
        xc = self.projection(merged)

        x_cat = torch.cat([xa, xb, xc], dim=-1)  #(B, T', 512)
        x_cat = self.pre_attn_dropout(x_cat)
        context = self.attn(x_cat)

        return self.mlp(context)

In [None]:
# ####  DATA SAMPLE to delete
# df_data = pd.read_csv(paths.TRAIN_CSV, nrows=5000)
# df = df_data.fillna(0)

# le = LabelEncoder()
# df['gesture_int'] = le.fit_transform(df['gesture'])
# np.save(paths.OUTPUT_DIR / "gesture_classes.npy", le.classes_)

# # print(df[['gesture_int', 'gesture', 'acc_x']].groupby('gesture').first())

# seq_gp = df.groupby('sequence_id') 
# seq_id, group_df = next(iter(seq_gp))
# # print("seq id", seq_id)
# # print("group df:", group_df[['gesture_int', 'gesture', 'acc_x']][:3])
# # print("df", df[['gesture_int', 'gesture', 'acc_x']][:3])

# all_steps_for_scaler_list = []
# X_list_unscaled, y_list_int_for_stratify, lens = [], [], [] 

# for seq_id, seq_df_orig in seq_gp:
#     seq_df = seq_df_orig.copy()

#     y_list_int_for_stratify.append(seq_df['gesture_int'].iloc[0])

# # print(y_list_int_for_stratify[:10])

# labels_tensor = torch.tensor(df['gesture_int'].values, dtype=torch.long)
# one_hot_tensor = F.one_hot(labels_tensor, num_classes=len(le.classes_))
# df['gesture_int_oh'] = one_hot_tensor.numpy().tolist()  # now each cell is a list

# subset_df = df[['gesture_int', 'gesture_int_oh']].head(500)
# subset_df.to_csv('gesture_with_onehot.csv', index=False)


In [53]:
### DATA CREATION and PRE PROCESSING

print("▶ TRAIN MODE – loading dataset …")

df_data = pd.read_csv(paths.TRAIN_CSV)
df_data = df_data.fillna(0)

train_dem_df = pd.read_csv(paths.TRAIN_DEMOGRAPHICS)
df = pd.merge(df_data.copy(), train_dem_df, on='subject', how='left')
print("merged df shape :", df.shape)

le = LabelEncoder()
df['gesture_int'] = le.fit_transform(df['gesture'])
np.save(paths.OUTPUT_DIR / "gesture_classes.npy", le.classes_)
gesture_classes = le.classes_

print_memory()

print(" 0/6 Calculating elbow_to_wrist_cm shoulder_to_wrist_cm adjustment ...")

df["acc_x_norm_ew"] = df["acc_x"] / df["elbow_to_wrist_cm"]
df["acc_y_norm_ew"] = df["acc_y"] / df["elbow_to_wrist_cm"]
df["acc_z_norm_ew"] = df["acc_z"] / df["elbow_to_wrist_cm"]

df["acc_x_norm_sw"] = df["acc_x"] / df["shoulder_to_wrist_cm"]
df["acc_y_norm_sw"] = df["acc_y"] / df["shoulder_to_wrist_cm"]
df["acc_z_norm_sw"] = df["acc_z"] / df["shoulder_to_wrist_cm"]

print(" 1/6 Calculating base engineered IMU features (magnitude, angle) ...")

df['acc_mag'] = np.sqrt(df['acc_x']**2 + df['acc_y']**2 + df['acc_z']**2)
df['rot_angle'] = 2* np.arccos(df['rot_w'].clip(-1, 1))

print(" 2/6 Calculating engineered IMU derivatives (jerk, angular velocity) for original acc_mag ...")

df['acc_mag_jerk'] = df.groupby('sequence_id')['acc_mag'].diff().fillna(0)
df['rot_angle_vel'] = df.groupby('sequence_id')['rot_angle'].diff().fillna(0)

print(" 3/6 Removing gravity and calculating linear acceleration features...")

linear_accel_list = []
for _, group in df.groupby('sequence_id'):
    acc_data_group = group[['acc_x', 'acc_y', 'acc_z']]
    rot_data_group = group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
    linear_accel_group = remove_gravity_from_acc(acc_data_group, rot_data_group)
    linear_accel_list.append(pd.DataFrame(linear_accel_group, columns=['linear_acc_x', 'linear_acc_y', 'linear_acc_z'], index=group.index))

df_linear_accel = pd.concat(linear_accel_list)
df = pd.concat([df, df_linear_accel], axis=1)
del df_linear_accel, linear_accel_list  # Memory Management
gc.collect()  # Memory Management

df['linear_acc_mag'] = np.sqrt(df['linear_acc_x']**2 + df['linear_acc_y']**2 + df['linear_acc_z']**2)
df['linear_acc_mag_jerk'] = df.groupby('sequence_id')['linear_acc_mag'].diff().fillna(0)

print(" 4/6 Calculating angular velocity from quaternion derivatives...")
angular_vel_list = []
for _, group in df.groupby('sequence_id'):
    rot_data_group = group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
    angular_vel_group = calculate_angular_velocity_from_quat(rot_data_group)
    angular_vel_list.append(pd.DataFrame(angular_vel_group, columns=['angular_vel_x', 'angular_vel_y', 'angular_vel_z'], index=group.index))

df_angular_vel = pd.concat(angular_vel_list)
df = pd.concat([df, df_angular_vel], axis=1)
del angular_vel_list, df_angular_vel # Memory Management
gc.collect() # Memory Management

print(" 5/6 Calculating angular distance between successive quaternions...")
angular_distance_list = []
for _, group in df.groupby('sequence_id'):
    rot_data_group = group[['rot_x', 'rot_y', 'rot_z', 'rot_w']]
    angular_dist_group = calculate_angular_distance(rot_data_group)
    angular_distance_list.append(pd.DataFrame(angular_dist_group, columns=['angular_distance'], index=group.index))

df_angular_distance = pd.concat(angular_distance_list)
df = pd.concat([df, df_angular_distance], axis=1)
del angular_distance_list, df_angular_distance # Memory Management
gc.collect() # Memory Management

print_memory()

meta_cols = { } # This was an empty dict in your provided code, keeping it as is.

print(" 6/6 Calculating imu_cols_base ...")
imu_cols_orig = ['acc_x', 'acc_y', 'acc_z',
            'rot_w', 'rot_x', 'rot_y', 'rot_z',
            'thm_1', 'thm_2', 'thm_3', 'thm_4', 'thm_5']

imu_cols_base = ['linear_acc_x', 'linear_acc_y', 'linear_acc_z']
imu_cols_base.extend([c for c in df.columns if c.startswith('rot_') and c not in ['rot_angle', 'rot_angle_vel']])

imu_engineered_features = [
    'acc_x_norm_ew', 'acc_y_norm_ew', 'acc_z_norm_ew',  # new from demographics
    'acc_x_norm_sw', 'acc_y_norm_sw', 'acc_z_norm_sw',  # new from demographics
    'acc_mag', 'rot_angle',
    'acc_mag_jerk', 'rot_angle_vel',
    'linear_acc_mag', 'linear_acc_mag_jerk',
    'angular_vel_x', 'angular_vel_y', 'angular_vel_z', # Existing new features
    'angular_distance' # Added new feature
]

dem_features = [
    'adult_child', 'age',
    'sex', 'handedness',
]

imu_cols = list(dict.fromkeys(imu_cols_orig + imu_cols_base + imu_engineered_features + dem_features))  # Remove dups

print("length of imu_cols :", len(imu_cols), "Obtaining tof columns ......")

tof_columns = [col for col in df.columns if col.startswith("tof_")]
tof_columns = sorted(tof_columns, key=parse_tof_column)

sequence_ids = df["sequence_id"].unique()

print("tof_columns length :", len(tof_columns))

del imu_cols_orig, imu_cols_base, imu_engineered_features, dem_features # Memory Management
gc.collect() # Memory Management

print("✅ Preprocessing done.")
print_memory()

# thm_cols_original = [c for c in df.columns if c.startswith('thm_')

▶ TRAIN MODE – loading dataset …
merged df shape : (574945, 348)
Memory Usage: 3106.66 MB
 0/6 Calculating elbow_to_wrist_cm shoulder_to_wrist_cm adjustment ...
 1/6 Calculating base engineered IMU features (magnitude, angle) ...
 2/6 Calculating engineered IMU derivatives (jerk, angular velocity) for original acc_mag ...
 3/6 Removing gravity and calculating linear acceleration features...
 4/6 Calculating angular velocity from quaternion derivatives...
 5/6 Calculating angular distance between successive quaternions...
Memory Usage: 3571.73 MB
 6/6 Calculating imu_cols_base ...
length of imu_cols : 35 Obtaining tof columns ......
tof_columns length : 320
✅ Preprocessing done.
Memory Usage: 3572.50 MB


In [54]:
### DATA CONFIGURATION

# Estimate the max length
sequence_lengths = df.groupby('sequence_id').size().values  # length of each sequence
SEQUENCE_LENGTH = int(np.percentile(sequence_lengths, 90))
print("SEQUENCE_LENGTH :", SEQUENCE_LENGTH)

X_2dim = df[imu_cols + tof_columns]
X_list = []
y_list = []

for seq_id, group in df.groupby('sequence_id', sort=False):
    X_seq = group[imu_cols + tof_columns].values.astype(np.float32)
    X_list.append(X_seq)
    y_list.append(group['gesture_int'].iloc[0])

X_padded = np.stack([pad_or_truncate(seq, SEQUENCE_LENGTH) for seq in X_list])
X = torch.tensor(X_padded, dtype=torch.float32)
y = F.one_hot(torch.tensor(np.array(y_list)), num_classes=len(le.classes_)).float()
print(f"X shape {X.shape} | y shape {y.shape}")

X_tr, X_val, y_tr, y_val = train_test_split(
    X, y,
    test_size=0.2,  # 20% validation
    random_state=42,
    stratify=df.groupby("sequence_id")["gesture"].first()  # keeps gesture label distribution balanced
)

## Sanity Check
for i, (seq_id, group) in enumerate(df.groupby('sequence_id', sort=False)):
    assert y_list[i] == group['gesture_int'].iloc[0]
print("length of imu", len(imu_cols))
print("length of tof", len(tof_columns))

train_dataset = MotionDataset(X_tr, y_tr, alpha=0.2)
val_dataset   = MotionDataset(X_val, y_val, alpha=0.2)

train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE_TRAIN, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=config.BATCH_SIZE_VALID, shuffle=False, num_workers=0)

del X_list, y_list, X_padded

print_memory()

SEQUENCE_LENGTH : 103
X shape torch.Size([8151, 103, 355]) | y shape torch.Size([8151, 18])
length of imu 35
length of tof 320
Memory Usage: 4410.70 MB


In [55]:

model = TwoBranchGestureModel(
    imu_dim=len(imu_cols),         # channels per node (ToF + IMU)
    tof_dim=len(tof_columns),         # channels per node (ToF + IMU)
    n_classes=len(df["gesture"].unique()),  # e.g., 20
).to(device)

In [64]:
print("⏩ training started .....")

sequence_labels = df.groupby('sequence_id').first()['gesture_int'].values
print(np.unique(sequence_labels)) 
cw_vals = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(sequence_labels),
    y=torch.argmax(y, dim=1).numpy()
    )
print(cw_vals.shape)  # should be (num_classes,)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
weights_tensor = torch.tensor(cw_vals, dtype=torch.float32).to(device)
loss_fn = nn.CrossEntropyLoss(weight=weights_tensor, label_smoothing=0.1)

print("▶️ Setting scheduler  .....")
steps = []
lrs = []
best_val_acc = 0
patience, patience_counter = 10, 0
EPOCHS = config.EPOCHS
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=config.EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.0,
    anneal_strategy="cos",
    final_div_factor=100,
)

print("✅ Epoch starts .....")
import itertools

max_batches = 5

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    correct = 0         # <-- reset here
    total = 0           # <-- reset here
    # for batch_idx, batch in tqdm.tqdm(enumerate(itertools.islice(train_loader, max_batches))):
    for batch_idx, batch in tqdm.tqdm(enumerate(train_loader)):
        xb, yb = batch[0].to(device), batch[1].to(device)
        # if batch_idx == 0:
        #     print(f"Batch {batch_idx}: x_imu shape {x_imu.shape}, x_tof shape {x_tof.shape}")

        if batch_idx == 0:
            print(f"Batch {batch_idx}: x_imu shape {xb.shape}, x_tof shape {yb.shape}")
        optimizer.zero_grad()        
        logits = model(xb)
        # print("logits.shape:", logits.shape)
        # print("yb_indices.shape:", yb.shape)
        yb_indices = yb.argmax(dim=1)
        loss = loss_fn(logits, yb_indices)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # optional
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        lrs.append(scheduler.get_last_lr()[0])
        steps.append(epoch * config.BATCH_SIZE_TRAIN + batch_idx)

        
        logits_arg = logits.argmax(dim=1)
        correct += (logits_arg == yb_indices).sum().item()
        total += yb_indices.size(0)

    train_acc = correct / total
    print(f"Epoch {epoch} | Train Loss: {total_loss / len(train_loader):.4f} | Train Acc: {train_acc:.4f}")
    
    # Validation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            xb, yb = batch[0].to(device), batch[1].to(device)

            x_preds = model(xb)
            logits = x_preds.argmax(dim=1)
            true_labels = yb.argmax(1) if yb.ndim > 1 else yb  #.argmax(1)  val_loader comes from a standard dataset with "y" as class index (long), you don’t need argmax.
            correct += (logits == true_labels).sum().item()
            total += true_labels.size(0)
    val_acc = correct / total
    print(f"Epoch {epoch} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), paths.OUTPUT_DIR / "best_model.pt")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

⏩ training started .....
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17]
(18,)
▶️ Setting scheduler  .....
✅ Epoch starts .....


1it [00:00,  7.60it/s]

Batch 0: x_imu shape torch.Size([8, 103, 355]), x_tof shape torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


3it [00:00,  7.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


5it [00:00,  7.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


7it [00:00,  7.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


9it [00:01,  8.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


11it [00:01,  7.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


14it [00:01,  8.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


15it [00:01,  8.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


17it [00:02,  8.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


21it [00:02,  9.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


23it [00:02,  9.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


25it [00:02, 10.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


29it [00:03, 10.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


31it [00:03, 10.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


35it [00:03,  9.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


36it [00:03,  9.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


40it [00:04, 10.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


42it [00:04, 10.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


44it [00:04, 10.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


48it [00:05,  9.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


49it [00:05,  9.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


53it [00:05, 10.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


55it [00:05,  9.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


59it [00:06, 10.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


61it [00:06, 10.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


63it [00:06, 10.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


65it [00:06, 10.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


67it [00:07,  8.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


71it [00:07,  9.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


73it [00:07, 10.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


76it [00:08,  9.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


77it [00:08,  9.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


79it [00:08,  7.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


81it [00:08,  7.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


83it [00:09,  7.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


85it [00:09,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


87it [00:09,  5.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


88it [00:10,  5.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


90it [00:10,  6.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


93it [00:10,  6.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


95it [00:10,  7.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


97it [00:11,  8.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


99it [00:11,  8.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


101it [00:11,  8.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


102it [00:11,  8.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


104it [00:12,  7.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


106it [00:12,  7.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


108it [00:12,  7.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


110it [00:12,  6.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


112it [00:13,  6.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


114it [00:13,  6.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


116it [00:13,  6.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


118it [00:14,  6.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


120it [00:14,  6.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


122it [00:14,  6.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


124it [00:15,  7.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


126it [00:15,  6.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


128it [00:15,  6.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


130it [00:15,  6.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


132it [00:16,  6.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


134it [00:16,  7.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


136it [00:16,  7.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


138it [00:17,  7.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


141it [00:17,  7.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


142it [00:17,  7.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


144it [00:17,  7.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


146it [00:18,  7.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


148it [00:18,  8.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


151it [00:18,  8.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


154it [00:18,  8.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


155it [00:19,  8.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


158it [00:19,  8.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


160it [00:19,  8.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


161it [00:19,  8.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


164it [00:20,  8.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


166it [00:20,  8.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


167it [00:20,  8.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


169it [00:20,  8.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


171it [00:21,  7.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


174it [00:21,  8.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


175it [00:21,  8.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


177it [00:21,  7.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


180it [00:22,  8.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


182it [00:22,  8.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


183it [00:22,  8.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


185it [00:22,  8.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


188it [00:23,  8.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


190it [00:23,  8.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


192it [00:23,  9.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


194it [00:23,  8.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


195it [00:23,  7.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


197it [00:24,  8.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


199it [00:24,  8.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


202it [00:24,  8.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


203it [00:24,  8.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


205it [00:24,  8.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


207it [00:25,  8.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


209it [00:25,  8.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


212it [00:25,  7.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


214it [00:26,  8.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


215it [00:26,  8.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


218it [00:26,  8.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


220it [00:26,  8.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


222it [00:27,  8.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


224it [00:27,  8.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


227it [00:27,  8.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


229it [00:27,  8.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


231it [00:28,  9.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


232it [00:28,  8.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


234it [00:28,  8.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


236it [00:28,  7.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


238it [00:29,  6.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


240it [00:29,  7.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


242it [00:29,  7.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


244it [00:29,  7.15it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


246it [00:30,  6.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


248it [00:30,  6.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


250it [00:30,  6.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


252it [00:31,  6.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


254it [00:31,  6.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


257it [00:31,  7.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


258it [00:31,  7.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


260it [00:32,  7.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


262it [00:32,  7.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


264it [00:32,  7.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


266it [00:32,  7.93it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


268it [00:33,  7.93it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


270it [00:33,  7.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


273it [00:33,  8.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


274it [00:33,  7.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


276it [00:34,  7.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


278it [00:34,  7.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


280it [00:34,  8.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


282it [00:34,  8.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


284it [00:35,  7.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


286it [00:35,  7.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


288it [00:35,  7.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


290it [00:36,  7.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


292it [00:36,  7.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


294it [00:36,  6.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


296it [00:36,  7.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


298it [00:37,  7.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


300it [00:37,  7.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


302it [00:37,  7.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


304it [00:37,  7.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


306it [00:38,  7.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


308it [00:38,  6.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


310it [00:38,  6.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


312it [00:39,  6.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


314it [00:39,  6.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


316it [00:39,  6.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


318it [00:40,  6.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


320it [00:40,  6.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


322it [00:40,  6.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


324it [00:41,  6.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


326it [00:41,  6.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


328it [00:41,  6.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


330it [00:42,  6.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


332it [00:42,  6.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


334it [00:42,  6.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


336it [00:43,  6.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


338it [00:43,  6.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


340it [00:43,  6.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


342it [00:43,  6.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


344it [00:44,  5.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


346it [00:44,  5.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


348it [00:45,  6.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


350it [00:45,  6.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


352it [00:45,  6.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


354it [00:46,  6.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


356it [00:46,  5.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


358it [00:46,  5.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


360it [00:47,  5.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


361it [00:47,  5.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


363it [00:47,  6.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


365it [00:47,  7.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


367it [00:48,  6.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


369it [00:48,  7.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


371it [00:48,  7.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


373it [00:48,  7.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


375it [00:49,  7.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


377it [00:49,  7.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


379it [00:49,  7.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


382it [00:50,  8.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


383it [00:50,  8.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


385it [00:50,  8.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


388it [00:50,  8.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


389it [00:50,  7.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


391it [00:51,  8.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


394it [00:51,  8.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


395it [00:51,  8.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


398it [00:51,  8.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


399it [00:52,  8.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


402it [00:52,  9.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


406it [00:52, 10.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


408it [00:52, 10.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


412it [00:53, 10.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


414it [00:53, 10.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


416it [00:53,  9.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


419it [00:54,  9.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


423it [00:54,  9.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


424it [00:54,  9.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


428it [00:54, 10.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


430it [00:55, 10.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


434it [00:55, 10.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


437it [00:55,  9.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


438it [00:56,  7.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


440it [00:56,  7.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


442it [00:56,  7.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


446it [00:57,  9.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


448it [00:57, 10.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


452it [00:57, 10.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


454it [00:57,  8.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


457it [00:58,  8.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


458it [00:58,  7.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


460it [00:58,  7.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


463it [00:59,  8.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


465it [00:59,  9.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


468it [00:59,  9.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


472it [00:59,  9.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


474it [01:00,  9.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


478it [01:00, 10.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


480it [01:00, 10.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


484it [01:01, 10.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


486it [01:01,  9.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


488it [01:01, 10.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


492it [01:01,  8.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


493it [01:02,  9.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


495it [01:02,  9.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


497it [01:02,  8.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


501it [01:02,  9.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


504it [01:03, 10.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


506it [01:03, 10.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


508it [01:03,  9.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


512it [01:04, 10.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


514it [01:04, 10.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


518it [01:04, 10.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


520it [01:04, 10.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


524it [01:05, 10.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


526it [01:05, 11.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


528it [01:05, 11.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


532it [01:05,  9.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


535it [01:06,  9.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


537it [01:06,  9.93it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


540it [01:06,  9.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


542it [01:06,  9.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


546it [01:07, 10.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


548it [01:07, 11.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


552it [01:07, 10.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


554it [01:08, 10.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


556it [01:08, 10.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


560it [01:08, 10.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


564it [01:09,  8.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


566it [01:09,  9.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


568it [01:09,  9.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


570it [01:09,  9.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


572it [01:09, 10.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


575it [01:10,  9.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


579it [01:10, 10.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


581it [01:10, 10.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


583it [01:11,  9.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


587it [01:11, 10.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


589it [01:11, 10.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


593it [01:12, 10.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


595it [01:12,  9.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


599it [01:12, 10.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


601it [01:12,  9.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


603it [01:13,  9.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


607it [01:13,  9.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


608it [01:13,  9.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


612it [01:13, 10.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


614it [01:14, 10.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


616it [01:14, 10.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


618it [01:14,  9.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


620it [01:14,  8.72it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


622it [01:15,  9.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


626it [01:15,  9.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


628it [01:15,  9.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


630it [01:15, 10.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


634it [01:16, 10.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


636it [01:16,  9.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


638it [01:16,  8.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


641it [01:17,  9.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


643it [01:17,  9.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


647it [01:17, 10.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


649it [01:17, 10.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


653it [01:18, 11.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


655it [01:18, 11.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


659it [01:18, 10.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


661it [01:18, 10.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


665it [01:19, 11.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


667it [01:19, 10.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


671it [01:19, 11.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


673it [01:19, 11.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


677it [01:20, 11.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


679it [01:20, 11.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


683it [01:20, 11.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


685it [01:20, 11.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


689it [01:21, 12.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


691it [01:21, 11.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


695it [01:21, 11.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


697it [01:21, 11.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


701it [01:22, 11.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


703it [01:22, 11.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


707it [01:22, 11.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


709it [01:22, 11.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


713it [01:23, 11.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


717it [01:23, 12.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


719it [01:23, 12.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


721it [01:23, 11.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


725it [01:24, 11.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


727it [01:24, 12.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


731it [01:24, 10.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


733it [01:25, 10.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


736it [01:25,  9.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


738it [01:25, 10.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


742it [01:25, 10.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


744it [01:26, 11.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


748it [01:26, 10.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


750it [01:26, 11.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


754it [01:26, 11.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


756it [01:27, 11.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


760it [01:27, 12.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


762it [01:27, 12.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


766it [01:28, 11.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


768it [01:28, 11.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


772it [01:28, 11.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


774it [01:28, 11.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


778it [01:29, 10.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


780it [01:29, 11.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


784it [01:29, 10.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


786it [01:29, 10.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


790it [01:30, 11.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


792it [01:30, 11.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


796it [01:30, 10.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


798it [01:30, 10.72it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


802it [01:31, 11.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


804it [01:31, 11.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


808it [01:31, 11.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


810it [01:31, 12.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


814it [01:32, 11.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


815it [01:32,  8.82it/s]


Epoch 0 | Train Loss: 2.1240 | Train Acc: 0.4317
Epoch 0 | Val Acc: 0.4887


1it [00:00,  9.43it/s]

Batch 0: x_imu shape torch.Size([8, 103, 355]), x_tof shape torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


3it [00:00,  9.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


5it [00:00,  9.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


8it [00:00,  9.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


11it [00:01,  9.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


13it [00:01,  9.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


14it [00:01,  8.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


17it [00:01,  8.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


19it [00:02,  8.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


20it [00:02,  8.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


23it [00:02,  9.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


26it [00:02,  9.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


27it [00:02,  9.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


31it [00:03,  9.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


34it [00:03,  9.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


36it [00:03,  9.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


38it [00:04,  9.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


41it [00:04,  9.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


44it [00:04,  9.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


46it [00:04,  9.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


50it [00:05, 10.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


52it [00:05,  9.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


55it [00:05,  9.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


58it [00:06,  9.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


60it [00:06, 10.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


64it [00:06, 10.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


66it [00:06, 10.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


69it [00:07,  9.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


70it [00:07,  9.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


73it [00:07,  8.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


75it [00:08,  8.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


77it [00:08,  8.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


79it [00:08,  8.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


81it [00:08,  8.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


83it [00:08,  8.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


87it [00:09,  9.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


89it [00:09,  9.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


91it [00:09,  9.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


94it [00:10,  9.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


96it [00:10, 10.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


100it [00:10, 11.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


102it [00:10, 11.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


104it [00:11, 10.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


108it [00:11, 11.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


110it [00:11, 10.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


114it [00:11, 11.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


116it [00:12, 11.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


120it [00:12, 11.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


122it [00:12, 11.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


126it [00:12, 11.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


128it [00:13, 11.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


132it [00:13, 10.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


134it [00:13, 10.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


138it [00:14, 11.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


142it [00:14, 11.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


144it [00:14, 11.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


146it [00:14, 11.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


150it [00:15, 11.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


152it [00:15, 11.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


154it [00:15,  9.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


156it [00:15,  9.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


158it [00:16,  8.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


162it [00:16,  9.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


164it [00:16, 10.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


168it [00:16, 11.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


170it [00:17, 11.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


174it [00:17, 11.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


176it [00:17, 11.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


180it [00:17, 10.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


182it [00:18, 10.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


184it [00:18,  9.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


188it [00:18, 10.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


190it [00:18, 10.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


194it [00:19, 10.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


196it [00:19, 10.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


198it [00:19, 10.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


201it [00:20,  9.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


205it [00:20, 10.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


207it [00:20,  9.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


209it [00:20,  9.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


212it [00:21,  9.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


215it [00:21, 10.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


217it [00:21, 10.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


221it [00:22, 10.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


223it [00:22, 10.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


227it [00:22, 10.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


229it [00:22, 11.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


233it [00:23, 11.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


235it [00:23, 11.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


239it [00:23, 11.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


241it [00:23, 10.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


245it [00:24, 10.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


247it [00:24, 10.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


251it [00:24, 11.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


253it [00:24, 11.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


257it [00:25, 11.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


259it [00:25, 11.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


263it [00:25, 11.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


265it [00:25, 11.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


267it [00:26, 11.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


271it [00:26, 10.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


273it [00:26, 10.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


277it [00:27, 10.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


279it [00:27, 10.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


281it [00:27, 10.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


283it [00:27, 10.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


287it [00:28, 10.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


289it [00:28, 10.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


291it [00:28,  9.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


292it [00:28,  8.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


295it [00:29,  9.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


297it [00:29,  9.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


301it [00:29, 11.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


303it [00:29, 11.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


307it [00:30, 11.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


309it [00:30, 12.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


313it [00:30, 11.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


315it [00:30, 11.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


319it [00:31, 10.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


321it [00:31, 10.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


323it [00:31, 10.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


327it [00:31, 10.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


329it [00:32, 10.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


331it [00:32, 10.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


333it [00:32,  9.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


337it [00:32,  9.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


339it [00:33,  9.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


341it [00:33,  9.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


343it [00:33, 10.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


347it [00:33, 10.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


349it [00:34, 10.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


351it [00:34, 10.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


355it [00:34, 10.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


357it [00:34, 10.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


361it [00:35, 10.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


363it [00:35, 10.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


367it [00:35, 10.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


369it [00:36, 10.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


371it [00:36, 10.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


375it [00:36, 10.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


377it [00:36, 10.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


381it [00:37, 10.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


383it [00:37, 10.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


386it [00:37,  9.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


388it [00:37,  8.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


390it [00:38,  9.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


393it [00:38,  9.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


396it [00:38,  9.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


398it [00:39,  9.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


401it [00:39,  9.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


403it [00:39,  9.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


404it [00:39,  8.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


407it [00:40,  9.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


409it [00:40, 10.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


413it [00:40, 11.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


417it [00:40, 11.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


419it [00:41, 11.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


421it [00:41, 11.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


425it [00:41, 10.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


427it [00:41, 10.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


429it [00:42, 10.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


432it [00:42,  9.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


434it [00:42, 10.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


436it [00:42, 10.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


440it [00:43, 10.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


442it [00:43, 11.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


446it [00:43, 11.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


448it [00:43, 11.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


452it [00:44, 11.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


454it [00:44, 10.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


456it [00:44,  8.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


459it [00:44,  8.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


462it [00:45,  9.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


464it [00:45, 10.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


466it [00:45,  9.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


470it [00:46, 10.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


472it [00:46, 10.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


476it [00:46, 10.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


478it [00:46, 10.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


482it [00:47, 10.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


484it [00:47, 10.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


486it [00:47, 11.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


488it [00:47, 10.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


492it [00:48, 10.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


494it [00:48,  9.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


496it [00:48,  9.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


498it [00:48, 10.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


500it [00:48, 10.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


504it [00:49, 10.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


506it [00:49,  9.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


508it [00:49, 10.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


510it [00:49,  9.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


514it [00:50, 10.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


516it [00:50, 10.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


518it [00:50, 10.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


522it [00:51, 10.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


524it [00:51, 10.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


528it [00:51, 10.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


530it [00:51, 10.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


532it [00:52, 10.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


534it [00:52, 10.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


536it [00:52,  9.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


538it [00:52,  9.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


542it [00:53,  9.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


544it [00:53, 10.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


548it [00:53, 11.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


550it [00:53, 10.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


554it [00:54, 11.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


557it [00:54,  8.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


558it [00:54,  9.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


562it [00:55, 10.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


564it [00:55, 10.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


568it [00:55, 11.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


570it [00:55, 10.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


574it [00:56, 11.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


576it [00:56, 11.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


578it [00:56, 10.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


582it [00:56, 10.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


584it [00:57, 10.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


586it [00:57, 10.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


590it [00:57, 10.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


592it [00:57, 10.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


596it [00:58, 11.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


598it [00:58, 11.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


600it [00:58, 10.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


604it [00:58, 11.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


606it [00:59, 11.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


610it [00:59, 11.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


612it [00:59, 11.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


616it [00:59, 10.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


618it [01:00, 10.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


621it [01:00,  7.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


624it [01:01,  7.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


626it [01:01,  8.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


628it [01:01,  8.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


630it [01:01,  8.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


633it [01:02,  8.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


634it [01:02,  8.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


637it [01:02,  8.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


638it [01:02,  8.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


640it [01:02,  8.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


643it [01:03,  9.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


645it [01:03,  7.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


647it [01:03,  7.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


649it [01:04,  7.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


653it [01:04,  8.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


654it [01:04,  8.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


656it [01:04,  7.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


658it [01:05,  7.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


660it [01:05,  7.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


663it [01:05,  8.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


664it [01:05,  8.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


666it [01:06,  6.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


667it [01:06,  5.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


668it [01:06,  5.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


670it [01:07,  4.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


671it [01:07,  4.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


672it [01:07,  4.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


673it [01:07,  4.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


675it [01:08,  5.15it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


677it [01:08,  5.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


679it [01:08,  5.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


681it [01:09,  4.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


682it [01:09,  4.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


683it [01:09,  4.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


684it [01:10,  4.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


685it [01:10,  4.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


687it [01:10,  5.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


689it [01:10,  5.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


691it [01:11,  5.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


693it [01:11,  5.15it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


694it [01:11,  5.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


695it [01:12,  4.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


697it [01:12,  4.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


698it [01:12,  4.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


699it [01:13,  4.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


700it [01:13,  4.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


702it [01:13,  4.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


704it [01:14,  4.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


705it [01:14,  4.72it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


707it [01:14,  4.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


708it [01:15,  4.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


709it [01:15,  4.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


710it [01:15,  4.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


711it [01:15,  4.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


712it [01:15,  4.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


714it [01:16,  4.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


715it [01:16,  4.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


716it [01:16,  4.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


717it [01:17,  4.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


718it [01:17,  4.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


719it [01:17,  4.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


720it [01:17,  4.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


721it [01:18,  4.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


722it [01:18,  4.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


723it [01:18,  4.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


724it [01:18,  4.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


726it [01:19,  3.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


727it [01:19,  4.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


728it [01:19,  3.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


729it [01:20,  4.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


730it [01:20,  4.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


731it [01:20,  4.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


732it [01:20,  4.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


733it [01:21,  4.15it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


734it [01:21,  4.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


735it [01:21,  4.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


736it [01:21,  4.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


738it [01:22,  3.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


739it [01:22,  3.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


740it [01:22,  3.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


741it [01:23,  3.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


743it [01:23,  3.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


744it [01:24,  3.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


746it [01:24,  3.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


747it [01:25,  3.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


749it [01:25,  3.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


750it [01:26,  3.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


751it [01:26,  3.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


752it [01:26,  3.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


753it [01:26,  3.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


754it [01:26,  4.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


756it [01:27,  4.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


757it [01:27,  4.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


759it [01:27,  4.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


761it [01:28,  4.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


763it [01:28,  4.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


764it [01:29,  4.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


766it [01:29,  4.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


767it [01:29,  4.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


768it [01:29,  4.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


769it [01:30,  4.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


770it [01:30,  4.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


772it [01:30,  5.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


774it [01:31,  5.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


775it [01:31,  4.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


776it [01:31,  4.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


777it [01:31,  4.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


778it [01:31,  4.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


779it [01:32,  4.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


780it [01:32,  4.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


781it [01:32,  4.72it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


783it [01:33,  4.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


785it [01:33,  4.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


786it [01:33,  4.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


787it [01:33,  4.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


789it [01:34,  4.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


791it [01:34,  4.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


793it [01:35,  4.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


794it [01:35,  4.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


795it [01:35,  3.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


796it [01:35,  3.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


797it [01:36,  4.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


799it [01:36,  3.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


800it [01:37,  4.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


801it [01:37,  4.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


802it [01:37,  4.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


804it [01:37,  4.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


805it [01:38,  4.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


806it [01:38,  4.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


807it [01:38,  4.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


808it [01:38,  3.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


809it [01:39,  3.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


810it [01:39,  4.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


811it [01:39,  4.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


812it [01:39,  4.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


813it [01:40,  3.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


814it [01:40,  4.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


815it [01:40,  8.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
Epoch 1 | Train Loss: 1.9782 | Train Acc: 0.5084





Epoch 1 | Val Acc: 0.5530


0it [00:00, ?it/s]

Batch 0: x_imu shape torch.Size([8, 103, 355]), x_tof shape torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


2it [00:00,  4.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


3it [00:00,  4.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


4it [00:00,  4.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


5it [00:01,  4.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


7it [00:01,  4.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


8it [00:01,  4.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


9it [00:02,  4.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


10it [00:02,  4.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


11it [00:02,  4.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


13it [00:03,  3.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


14it [00:03,  4.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


15it [00:03,  4.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


16it [00:03,  4.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


17it [00:03,  4.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


19it [00:04,  4.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


21it [00:04,  4.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


23it [00:05,  4.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


24it [00:05,  4.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


25it [00:05,  4.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


26it [00:06,  3.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


28it [00:06,  3.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


29it [00:06,  3.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


31it [00:07,  4.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


33it [00:07,  4.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


35it [00:08,  5.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


37it [00:08,  5.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


39it [00:08,  5.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


41it [00:09,  5.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


43it [00:09,  4.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


45it [00:10,  3.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


46it [00:10,  3.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


47it [00:10,  4.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


48it [00:11,  4.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


49it [00:11,  4.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


50it [00:11,  4.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


52it [00:11,  4.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


54it [00:12,  5.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


56it [00:12,  5.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


58it [00:13,  4.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


59it [00:13,  4.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


60it [00:13,  4.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


62it [00:13,  4.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


64it [00:14,  5.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


66it [00:14,  5.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


68it [00:15,  5.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


70it [00:15,  5.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


72it [00:15,  5.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


74it [00:16,  5.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


76it [00:16,  5.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


78it [00:16,  5.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


80it [00:17,  5.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


82it [00:17,  5.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


84it [00:18,  5.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


85it [00:18,  5.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


87it [00:18,  5.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


89it [00:19,  5.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


91it [00:19,  5.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


93it [00:19,  6.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


95it [00:19,  6.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


97it [00:20,  6.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


99it [00:20,  6.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


101it [00:20,  5.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


103it [00:21,  5.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


105it [00:21,  5.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


107it [00:22,  5.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


109it [00:22,  5.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


111it [00:22,  5.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


113it [00:23,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


115it [00:23,  6.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


117it [00:23,  6.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


119it [00:23,  6.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


121it [00:24,  7.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


123it [00:24,  6.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


125it [00:24,  6.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


127it [00:25,  6.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


129it [00:25,  6.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


131it [00:25,  6.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


133it [00:26,  6.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


135it [00:26,  5.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


137it [00:26,  5.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


139it [00:27,  5.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


141it [00:27,  5.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


143it [00:27,  5.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


145it [00:28,  5.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


147it [00:28,  5.72it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


149it [00:29,  5.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


150it [00:29,  5.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


152it [00:29,  6.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


154it [00:29,  4.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


155it [00:30,  4.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


156it [00:30,  4.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


158it [00:30,  5.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


160it [00:31,  5.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


162it [00:31,  5.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


164it [00:31,  5.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


166it [00:32,  5.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


168it [00:32,  5.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


170it [00:32,  5.93it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


172it [00:33,  6.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


174it [00:33,  6.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


176it [00:33,  6.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


178it [00:34,  5.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


180it [00:34,  6.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


182it [00:34,  5.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


184it [00:35,  6.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


186it [00:35,  6.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


188it [00:35,  6.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


190it [00:36,  5.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


191it [00:36,  5.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


193it [00:36,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


195it [00:37,  5.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


197it [00:37,  6.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


199it [00:37,  6.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


201it [00:38,  6.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


203it [00:38,  6.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


205it [00:38,  6.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


207it [00:38,  6.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


209it [00:39,  6.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


211it [00:39,  5.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


213it [00:40,  5.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


215it [00:40,  5.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


217it [00:40,  5.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


218it [00:40,  5.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


220it [00:41,  5.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


222it [00:41,  5.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


224it [00:42,  5.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


226it [00:42,  5.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


228it [00:42,  6.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


230it [00:43,  5.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


232it [00:43,  5.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


234it [00:43,  5.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


236it [00:44,  5.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


238it [00:44,  5.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


240it [00:44,  5.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


242it [00:45,  5.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


244it [00:45,  5.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


246it [00:45,  5.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


248it [00:46,  5.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


250it [00:46,  5.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


252it [00:47,  5.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


254it [00:47,  5.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


256it [00:47,  5.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


258it [00:48,  5.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


260it [00:48,  6.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


262it [00:48,  5.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


264it [00:49,  6.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


266it [00:49,  6.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


268it [00:49,  6.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


270it [00:49,  6.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


272it [00:50,  6.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


274it [00:50,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


276it [00:50,  5.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


278it [00:51,  6.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


280it [00:51,  6.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


282it [00:51,  7.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


284it [00:52,  7.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


286it [00:52,  7.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


288it [00:52,  7.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


290it [00:52,  6.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


292it [00:53,  6.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


294it [00:53,  6.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


296it [00:53,  6.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


298it [00:54,  6.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


300it [00:54,  6.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


302it [00:54,  5.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


304it [00:55,  5.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


306it [00:55,  6.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


308it [00:55,  6.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


310it [00:56,  6.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


312it [00:56,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


314it [00:56,  6.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


316it [00:57,  6.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


318it [00:57,  6.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


320it [00:57,  6.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


322it [00:58,  6.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


324it [00:58,  6.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


326it [00:58,  6.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


328it [00:58,  6.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


330it [00:59,  6.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


332it [00:59,  6.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


334it [00:59,  6.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


336it [01:00,  6.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


338it [01:00,  7.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


340it [01:00,  7.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


342it [01:00,  7.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


344it [01:01,  7.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


346it [01:01,  6.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


348it [01:01,  6.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


350it [01:02,  6.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


352it [01:02,  6.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


354it [01:02,  6.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


356it [01:03,  6.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


358it [01:03,  6.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


360it [01:03,  6.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


362it [01:04,  5.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


364it [01:04,  6.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


366it [01:04,  6.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


368it [01:05,  6.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


370it [01:05,  5.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


372it [01:05,  5.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


374it [01:06,  6.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


376it [01:06,  6.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


378it [01:06,  6.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


380it [01:06,  6.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


382it [01:07,  6.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


383it [01:07,  6.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


385it [01:07,  4.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


386it [01:08,  4.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


388it [01:08,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


390it [01:08,  5.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


392it [01:09,  5.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


394it [01:09,  6.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


396it [01:09,  6.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


398it [01:10,  5.72it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


400it [01:10,  5.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


402it [01:10,  5.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


403it [01:10,  5.61it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


405it [01:11,  3.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


406it [01:12,  2.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


408it [01:13,  2.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


409it [01:14,  2.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


410it [01:14,  2.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


411it [01:14,  3.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


412it [01:14,  3.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


414it [01:15,  4.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


415it [01:15,  4.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


417it [01:15,  4.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


419it [01:16,  4.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


421it [01:16,  4.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


423it [01:16,  5.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


425it [01:17,  5.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


427it [01:17,  6.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


429it [01:17,  5.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


431it [01:18,  6.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


433it [01:18,  6.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


435it [01:18,  6.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


437it [01:19,  6.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


439it [01:19,  6.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


441it [01:19,  6.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


443it [01:19,  6.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


445it [01:20,  5.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


447it [01:20,  4.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


448it [01:21,  4.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


449it [01:21,  4.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


450it [01:21,  4.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


452it [01:21,  4.93it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


454it [01:22,  5.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


456it [01:22,  5.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


458it [01:22,  6.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


460it [01:23,  6.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


462it [01:23,  6.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


464it [01:23,  5.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


465it [01:24,  5.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


466it [01:24,  4.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


468it [01:24,  4.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


469it [01:24,  4.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


471it [01:25,  4.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


472it [01:25,  4.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


474it [01:25,  5.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


476it [01:26,  5.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


478it [01:26,  5.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


480it [01:26,  5.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


482it [01:27,  6.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


484it [01:27,  6.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


486it [01:27,  6.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


488it [01:28,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


490it [01:28,  5.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


492it [01:28,  5.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


494it [01:29,  5.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


496it [01:29,  5.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


498it [01:30,  5.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


500it [01:30,  5.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


502it [01:30,  6.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


504it [01:30,  6.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


506it [01:31,  6.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


508it [01:31,  6.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


510it [01:31,  6.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


512it [01:32,  5.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


514it [01:32,  6.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


516it [01:32,  6.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


518it [01:33,  6.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


520it [01:33,  5.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


522it [01:33,  6.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


524it [01:34,  5.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


525it [01:34,  4.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


527it [01:34,  4.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


528it [01:35,  4.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


530it [01:35,  4.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


531it [01:35,  4.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


532it [01:35,  4.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


534it [01:36,  4.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


536it [01:36,  5.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


538it [01:36,  6.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


540it [01:37,  6.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


542it [01:37,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


544it [01:38,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


546it [01:38,  4.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


548it [01:38,  5.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


550it [01:39,  6.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


552it [01:39,  6.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


554it [01:39,  6.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


556it [01:40,  6.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


558it [01:40,  6.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


560it [01:40,  6.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


562it [01:41,  5.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


564it [01:41,  4.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


566it [01:42,  4.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


568it [01:42,  5.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


570it [01:42,  4.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


572it [01:43,  4.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


574it [01:43,  5.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


576it [01:43,  5.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


578it [01:44,  5.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


580it [01:44,  5.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


582it [01:44,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


584it [01:45,  5.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


586it [01:45,  5.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


588it [01:46,  5.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


589it [01:46,  4.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


590it [01:46,  4.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


592it [01:47,  3.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


593it [01:47,  3.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


595it [01:48,  3.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


596it [01:48,  3.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


598it [01:48,  4.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


600it [01:49,  4.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


602it [01:49,  5.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


604it [01:49,  5.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


606it [01:50,  5.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


607it [01:50,  5.39it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


609it [01:50,  5.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


611it [01:50,  5.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


613it [01:51,  6.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


615it [01:51,  6.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


617it [01:51,  6.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


619it [01:52,  6.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


621it [01:52,  6.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


622it [01:52,  5.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


624it [01:53,  6.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


626it [01:53,  6.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


628it [01:53,  5.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


630it [01:54,  5.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


632it [01:54,  6.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


634it [01:54,  6.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


636it [01:55,  5.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


638it [01:55,  4.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


639it [01:55,  4.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


641it [01:56,  4.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


643it [01:56,  5.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


645it [01:56,  5.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


647it [01:57,  6.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


649it [01:57,  6.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


651it [01:57,  6.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


653it [01:58,  5.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


654it [01:58,  5.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


656it [01:58,  5.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


658it [01:59,  4.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


659it [01:59,  4.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


661it [01:59,  5.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


663it [01:59,  5.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


665it [02:00,  6.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


667it [02:00,  6.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


669it [02:00,  5.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


670it [02:01,  5.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


672it [02:01,  5.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


674it [02:01,  5.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


675it [02:02,  5.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


676it [02:02,  4.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


677it [02:02,  3.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


678it [02:03,  2.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


680it [02:04,  2.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


681it [02:04,  3.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


683it [02:04,  3.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


685it [02:05,  4.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


687it [02:05,  4.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


688it [02:05,  5.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


690it [02:06,  5.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


692it [02:06,  5.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


694it [02:06,  5.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


696it [02:07,  5.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


698it [02:07,  5.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


700it [02:07,  5.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


702it [02:08,  5.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


704it [02:08,  4.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


706it [02:09,  4.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


708it [02:09,  5.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


710it [02:09,  5.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


712it [02:10,  5.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


714it [02:10,  5.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


716it [02:10,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


718it [02:11,  6.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


720it [02:11,  6.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


722it [02:11,  6.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


724it [02:12,  5.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


726it [02:12,  4.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


728it [02:13,  3.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


729it [02:13,  4.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


730it [02:13,  4.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


731it [02:13,  4.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


733it [02:14,  4.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


735it [02:14,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


737it [02:14,  5.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


739it [02:15,  5.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


741it [02:15,  6.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


743it [02:15,  6.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


745it [02:16,  6.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


747it [02:16,  6.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


749it [02:16,  5.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


751it [02:17,  5.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


753it [02:17,  5.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


755it [02:18,  5.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


757it [02:18,  5.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


759it [02:18,  5.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


761it [02:19,  5.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


762it [02:19,  5.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


764it [02:19,  4.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


766it [02:20,  5.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


768it [02:20,  4.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


770it [02:20,  5.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


772it [02:21,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


774it [02:21,  6.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


776it [02:21,  5.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


778it [02:22,  6.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


780it [02:22,  6.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


782it [02:22,  6.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


784it [02:23,  4.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


785it [02:23,  4.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


786it [02:24,  3.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


788it [02:24,  3.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


790it [02:25,  4.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


792it [02:25,  4.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


794it [02:25,  5.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


796it [02:26,  4.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


797it [02:26,  4.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


799it [02:26,  5.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


801it [02:27,  5.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


803it [02:27,  6.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


805it [02:27,  6.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


807it [02:28,  6.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


809it [02:28,  5.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


811it [02:28,  5.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


813it [02:29,  5.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


814it [02:29,  5.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


815it [02:29,  5.45it/s]


Epoch 2 | Train Loss: 1.8639 | Train Acc: 0.5647
Epoch 2 | Val Acc: 0.5806


0it [00:00, ?it/s]

Batch 0: x_imu shape torch.Size([8, 103, 355]), x_tof shape torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


2it [00:00,  5.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


4it [00:00,  5.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


6it [00:01,  5.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


8it [00:01,  5.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


10it [00:01,  5.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


12it [00:02,  5.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


13it [00:02,  5.15it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


15it [00:02,  4.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


17it [00:03,  5.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


19it [00:03,  5.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


21it [00:04,  5.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


23it [00:04,  5.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


25it [00:04,  5.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


27it [00:05,  5.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


29it [00:05,  5.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


31it [00:05,  5.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


33it [00:06,  5.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


34it [00:06,  5.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


36it [00:07,  4.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


37it [00:07,  3.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


39it [00:07,  3.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


40it [00:08,  3.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


42it [00:09,  2.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


43it [00:09,  3.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


45it [00:09,  3.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


47it [00:10,  3.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


49it [00:10,  4.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


51it [00:11,  4.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


52it [00:11,  4.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


54it [00:11,  4.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


56it [00:12,  4.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


57it [00:12,  4.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


59it [00:12,  5.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


60it [00:12,  4.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


61it [00:13,  4.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


62it [00:13,  4.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


64it [00:13,  4.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


65it [00:14,  4.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


66it [00:14,  4.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


68it [00:14,  4.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


70it [00:15,  4.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


72it [00:15,  4.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


73it [00:15,  4.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


74it [00:16,  4.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


76it [00:16,  3.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


77it [00:16,  3.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


78it [00:17,  4.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


79it [00:17,  4.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


81it [00:17,  3.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


83it [00:18,  3.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


84it [00:18,  3.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


85it [00:19,  3.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


87it [00:19,  3.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


88it [00:20,  3.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


89it [00:20,  3.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


90it [00:20,  3.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


91it [00:21,  3.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


92it [00:21,  3.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


94it [00:21,  3.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


95it [00:22,  3.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


97it [00:22,  3.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


98it [00:23,  3.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


100it [00:23,  4.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


102it [00:23,  4.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


105it [00:24,  6.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


106it [00:24,  5.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


108it [00:25,  4.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


109it [00:25,  4.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


111it [00:25,  5.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


113it [00:26,  4.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


114it [00:26,  4.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


115it [00:26,  4.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


117it [00:26,  4.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


118it [00:27,  4.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


120it [00:27,  4.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


122it [00:27,  5.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


124it [00:28,  6.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


127it [00:28,  7.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


129it [00:28,  7.93it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


130it [00:28,  7.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


133it [00:29,  7.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


134it [00:29,  7.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


137it [00:29,  7.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


138it [00:29,  7.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


140it [00:30,  7.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


143it [00:30,  7.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


144it [00:30,  7.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


147it [00:31,  7.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


148it [00:31,  7.93it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


150it [00:31,  6.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


152it [00:31,  6.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


154it [00:32,  7.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


156it [00:32,  7.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


159it [00:32,  7.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


160it [00:32,  7.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


162it [00:33,  6.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


164it [00:33,  6.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


166it [00:33,  6.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


168it [00:34,  6.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


170it [00:34,  5.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


172it [00:35,  5.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


173it [00:35,  5.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


175it [00:35,  5.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


177it [00:36,  5.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


179it [00:36,  4.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


180it [00:36,  5.06it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


182it [00:37,  5.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


184it [00:37,  5.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


186it [00:37,  6.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


188it [00:37,  6.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


190it [00:38,  6.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


192it [00:38,  5.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


194it [00:38,  6.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


196it [00:39,  6.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


199it [00:39,  7.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


201it [00:39,  7.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


202it [00:40,  7.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


205it [00:40,  7.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


206it [00:40,  7.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


209it [00:40,  7.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


210it [00:41,  6.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


212it [00:41,  6.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


214it [00:41,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


215it [00:42,  5.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


216it [00:42,  4.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


217it [00:42,  4.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


219it [00:43,  3.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


221it [00:43,  4.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


223it [00:44,  3.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


224it [00:44,  3.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


225it [00:44,  3.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


226it [00:45,  3.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


228it [00:45,  3.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


229it [00:45,  3.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


231it [00:46,  3.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


232it [00:46,  3.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


233it [00:47,  3.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


235it [00:47,  4.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


237it [00:47,  4.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


239it [00:48,  5.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


241it [00:48,  6.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


244it [00:48,  7.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


245it [00:48,  7.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


247it [00:49,  7.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


250it [00:49,  7.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


251it [00:49,  7.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


253it [00:49,  7.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


256it [00:50,  7.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


257it [00:50,  6.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


259it [00:50,  6.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


261it [00:51,  7.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


263it [00:51,  7.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


265it [00:51,  7.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


267it [00:51,  6.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


269it [00:52,  6.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


271it [00:52,  7.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


274it [00:52,  7.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


276it [00:53,  8.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


277it [00:53,  8.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


280it [00:53,  7.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


282it [00:53,  7.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


283it [00:54,  7.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


285it [00:54,  6.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


287it [00:54,  6.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


290it [00:55,  7.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


291it [00:55,  7.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


293it [00:55,  6.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


296it [00:55,  7.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


297it [00:56,  7.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


299it [00:56,  7.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


301it [00:56,  7.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


303it [00:56,  6.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


305it [00:57,  6.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


307it [00:57,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


309it [00:58,  5.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


311it [00:58,  5.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


313it [00:58,  5.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


315it [00:59,  6.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


316it [00:59,  5.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


318it [00:59,  4.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


319it [01:00,  4.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


321it [01:00,  4.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


323it [01:00,  5.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


325it [01:01,  5.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


328it [01:01,  6.15it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


329it [01:01,  6.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


331it [01:02,  5.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


333it [01:02,  5.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


334it [01:02,  5.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


335it [01:02,  5.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


337it [01:03,  5.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


339it [01:03,  5.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


341it [01:03,  6.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


343it [01:04,  6.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


345it [01:04,  6.13it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


347it [01:04,  6.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


349it [01:05,  6.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


351it [01:05,  5.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


353it [01:05,  6.16it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


355it [01:05,  6.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


357it [01:06,  6.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


359it [01:06,  6.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


361it [01:06,  6.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


363it [01:07,  6.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


365it [01:07,  5.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


367it [01:07,  5.55it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


369it [01:08,  5.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


371it [01:08,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


373it [01:08,  6.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


375it [01:09,  5.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


377it [01:09,  6.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


379it [01:09,  6.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


381it [01:10,  7.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


383it [01:10,  7.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


385it [01:10,  7.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


387it [01:10,  7.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


389it [01:11,  7.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


392it [01:11,  7.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


394it [01:11,  7.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


395it [01:12,  7.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


397it [01:12,  7.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


399it [01:12,  7.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


401it [01:12,  8.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


403it [01:13,  8.07it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


406it [01:13,  8.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


407it [01:13,  7.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


409it [01:13,  6.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


411it [01:14,  6.67it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


413it [01:14,  6.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


416it [01:14,  6.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


417it [01:15,  7.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


420it [01:15,  7.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


421it [01:15,  7.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


424it [01:15,  7.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


425it [01:16,  5.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


426it [01:16,  5.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


427it [01:16,  4.53it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


428it [01:16,  4.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


429it [01:17,  4.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


430it [01:17,  4.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


432it [01:17,  5.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


435it [01:18,  6.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


436it [01:18,  6.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


438it [01:18,  6.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


441it [01:18,  7.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


442it [01:19,  7.54it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


444it [01:19,  6.73it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


447it [01:19,  7.72it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


449it [01:20,  8.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


450it [01:20,  7.93it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


453it [01:20,  8.22it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


454it [01:20,  7.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


457it [01:21,  7.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


458it [01:21,  7.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


460it [01:21,  7.66it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


462it [01:21,  7.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


464it [01:22,  7.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


467it [01:22,  8.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


468it [01:22,  8.15it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


470it [01:22,  8.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


472it [01:23,  7.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


474it [01:23,  7.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


476it [01:23,  7.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


479it [01:23,  7.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


481it [01:24,  7.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


482it [01:24,  7.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


485it [01:24,  7.57it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


486it [01:24,  7.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


488it [01:25,  7.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


490it [01:25,  7.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


493it [01:25,  7.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


494it [01:25,  8.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


496it [01:26,  7.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


499it [01:26,  7.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


500it [01:26,  7.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


502it [01:27,  6.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


505it [01:27,  7.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


506it [01:27,  7.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


508it [01:27,  7.47it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


511it [01:28,  8.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


512it [01:28,  7.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


515it [01:28,  7.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


517it [01:29,  7.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


518it [01:29,  7.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


521it [01:29,  7.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


522it [01:29,  7.31it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


525it [01:30,  7.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


526it [01:30,  7.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


528it [01:30,  6.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


530it [01:30,  6.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


532it [01:31,  5.45it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


533it [01:31,  5.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


534it [01:31,  5.25it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


536it [01:32,  5.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


538it [01:32,  6.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


540it [01:32,  6.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


542it [01:32,  6.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


544it [01:33,  6.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


546it [01:33,  6.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


548it [01:33,  6.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


550it [01:34,  6.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


552it [01:34,  6.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


554it [01:34,  5.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


556it [01:35,  6.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


558it [01:35,  6.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


560it [01:35,  6.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


562it [01:36,  5.89it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


564it [01:36,  6.24it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


566it [01:36,  6.36it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


568it [01:36,  6.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


570it [01:37,  6.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


572it [01:37,  6.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


574it [01:37,  6.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


576it [01:38,  6.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


578it [01:38,  6.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


580it [01:38,  6.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


582it [01:39,  6.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


584it [01:39,  7.09it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


586it [01:39,  6.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


589it [01:40,  7.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


591it [01:40,  7.59it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


593it [01:40,  7.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


594it [01:40,  8.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


597it [01:41,  7.02it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


599it [01:41,  7.56it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


600it [01:41,  6.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


602it [01:41,  6.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


604it [01:42,  6.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


607it [01:42,  7.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


608it [01:42,  7.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


610it [01:43,  6.41it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


612it [01:43,  6.64it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


614it [01:43,  6.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


617it [01:44,  7.68it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


618it [01:44,  7.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


620it [01:44,  7.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


623it [01:44,  7.74it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


624it [01:45,  7.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


626it [01:45,  7.46it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


628it [01:45,  7.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


631it [01:46,  7.15it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


632it [01:46,  7.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


634it [01:46,  7.21it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


636it [01:46,  6.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


638it [01:47,  6.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


640it [01:47,  6.12it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


642it [01:47,  6.26it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


645it [01:48,  6.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


647it [01:48,  7.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


649it [01:48,  7.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


650it [01:48,  7.75it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


652it [01:49,  7.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


655it [01:49,  7.35it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


656it [01:49,  7.30it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


658it [01:50,  6.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


660it [01:50,  7.11it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


663it [01:50,  7.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


665it [01:50,  7.98it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


667it [01:51,  7.94it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


668it [01:51,  6.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


671it [01:51,  7.44it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


672it [01:51,  7.51it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


674it [01:52,  6.42it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


677it [01:52,  7.00it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


679it [01:52,  7.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


681it [01:53,  7.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


682it [01:53,  7.17it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


685it [01:53,  7.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


686it [01:53,  7.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


688it [01:54,  6.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


690it [01:54,  7.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


693it [01:54,  7.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


694it [01:54,  7.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


697it [01:55,  7.82it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


698it [01:55,  6.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


700it [01:55,  6.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


702it [01:56,  6.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


703it [01:56,  5.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


704it [01:56,  4.76it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


705it [01:56,  4.87it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


707it [01:57,  3.71it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


708it [01:57,  3.95it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


710it [01:58,  3.78it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


712it [01:58,  4.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


714it [01:59,  4.80it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


716it [01:59,  5.03it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


718it [01:59,  5.05it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


720it [02:00,  3.70it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


721it [02:00,  3.50it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


722it [02:01,  3.32it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


724it [02:01,  2.99it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


725it [02:02,  2.88it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


726it [02:02,  2.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


727it [02:03,  2.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


728it [02:03,  2.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


729it [02:03,  2.77it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


730it [02:04,  2.81it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


731it [02:04,  2.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


732it [02:04,  2.79it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


734it [02:05,  2.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


735it [02:05,  3.28it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


736it [02:06,  3.43it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


737it [02:06,  3.33it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


739it [02:06,  3.49it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


740it [02:07,  3.69it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


741it [02:07,  3.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


742it [02:07,  3.92it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


744it [02:07,  4.48it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


746it [02:08,  3.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


748it [02:08,  4.40it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


750it [02:09,  4.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


752it [02:09,  5.14it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


754it [02:10,  5.19it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


756it [02:10,  5.90it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


758it [02:10,  6.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


760it [02:11,  6.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


762it [02:11,  7.20it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


764it [02:11,  6.62it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


767it [02:11,  7.63it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


768it [02:12,  7.34it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


770it [02:12,  6.91it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


772it [02:12,  6.58it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


774it [02:13,  6.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


776it [02:13,  6.96it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


778it [02:13,  7.23it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


781it [02:13,  7.97it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


782it [02:14,  7.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


784it [02:14,  7.10it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


787it [02:14,  7.65it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


788it [02:14,  7.29it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


791it [02:15,  6.86it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


792it [02:15,  6.83it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


795it [02:15,  7.60it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


797it [02:16,  7.84it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


798it [02:16,  6.85it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


800it [02:16,  7.04it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


802it [02:16,  7.27it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


804it [02:17,  7.37it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


806it [02:17,  6.18it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


808it [02:17,  6.01it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


810it [02:18,  6.52it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


812it [02:18,  6.38it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


814it [02:18,  6.08it/s]

logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])
logits.shape: torch.Size([8, 18])
yb_indices.shape: torch.Size([8, 18])


815it [02:18,  5.86it/s]


Epoch 3 | Train Loss: 1.8188 | Train Acc: 0.5908
Epoch 3 | Val Acc: 0.5825
