In [None]:
# Constants
BATCH_SIZE = 8  # studies loaded per process_batch call; also slices per series in each training sub-batch
CHUNK_SIZE = 8  # series per DataLoader batch in train_model_in_batches
START_IDX = 0  # default index of the first study process_batch loads
MAX_IMAGES_PER_SERIES = 128  # per-series slice cap passed to process_batch
CHECKPOINT_PATH = '/kaggle/working/model_checkpoint.pth'  # training checkpoints and save_final_model's default target
IMG_SIZE = (224, 224)  # (height, width) every slice is resized to
NUM_EPOCHS = 5  # default epoch count for train_model_in_batches
LEARNING_RATE = 5e-7  # AdamW learning rate; NOTE(review): unusually small for fine-tuning — confirm intended
LOSS_FUNCTION = 'binary_crossentropy'  # not referenced in the code shown (training uses nn.BCEWithLogitsLoss)
METRICS = ['accuracy']  # not referenced in the code shown

import os
import pandas as pd
import numpy as np
import pydicom
import psutil
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import models
from skimage.transform import resize
from sklearn.model_selection import train_test_split
import gc
import warnings
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler

# Default compute device for this notebook: CUDA when a GPU is visible, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def read_clinical_data(file_path):
    """Load the clinical-data sheet from an XLSX workbook.

    The sheet is found by a case-insensitive substring match of its name
    against 'clinical', the Russian 'клинические', or 'сlinical' spelled with
    a Cyrillic 'с' (a look-alike that a plain 'clinical' match would miss).
    The first matching sheet in workbook order is returned.

    Args:
        file_path: Path to the .xlsx workbook.

    Returns:
        pandas.DataFrame holding the matched sheet.

    Raises:
        ValueError: If no sheet name matches.
    """
    needles = ('clinical', 'клинические', 'сlinical')
    # Context manager so the workbook's file handle is always released
    # (the previous version never closed the ExcelFile).
    with pd.ExcelFile(file_path) as xls:
        for sheet_name in xls.sheet_names:
            lowered = sheet_name.lower()
            if any(needle in lowered for needle in needles):
                return pd.read_excel(xls, sheet_name)
    raise ValueError('No clinical sheet found in the XLSX file.')

def read_dicom_files(study_path):
    """Read every DICOM file in each series sub-directory of a study.

    Args:
        study_path: Directory whose immediate sub-directories are series
            folders containing DICOM files (extension '.dcm', any case).

    Returns:
        dict mapping series folder name -> list of pydicom datasets, ordered
        by file name. Series folders without DICOM files are omitted, and
        plain files directly under ``study_path`` are ignored.
    """
    dicom_series = {}
    # Sorted listings: os.listdir order is arbitrary, which made slice order
    # (and therefore any downstream per-series subsampling) non-deterministic.
    for series in sorted(os.listdir(study_path)):
        series_path = os.path.join(study_path, series)
        if not os.path.isdir(series_path):
            continue
        series_images = []
        for file in sorted(os.listdir(series_path)):
            # Case-insensitive so '*.DCM' files are not silently skipped.
            if not file.lower().endswith('.dcm'):
                continue
            # force=True also reads files that lack the standard DICOM preamble/meta header.
            dicom_image = pydicom.dcmread(os.path.join(series_path, file), force=True)
            # Fill in image tags when the header omits them, presumably so that
            # pixel_array can decode the data; assumes single-channel greyscale
            # input — TODO confirm for this dataset.
            if 'SamplesPerPixel' not in dicom_image:
                dicom_image.SamplesPerPixel = 1
            if 'PhotometricInterpretation' not in dicom_image:
                dicom_image.PhotometricInterpretation = 'MONOCHROME2'
            series_images.append(dicom_image)
        if series_images:
            dicom_series[series] = series_images
    return dicom_series

def preprocess_images(dicom_images):
    """Convert DICOM datasets into a stack of 3-channel float32 images.

    Each image is min-max scaled to [0, 1], resized to IMG_SIZE and, when it
    is 2-D greyscale, replicated into three identical channels so it can feed
    an RGB backbone.

    Args:
        dicom_images: Iterable of pydicom datasets exposing ``pixel_array``.

    Returns:
        numpy float32 array; shape (N, IMG_SIZE[0], IMG_SIZE[1], 3) for 2-D
        greyscale input, and an empty array when there are no images.
    """
    preprocessed_images = []
    for dicom_image in dicom_images:
        image_array = dicom_image.pixel_array.astype(np.float32)
        # Min-max scaling: raw stored CT values can be negative (e.g. padding
        # outside the scan field), which dividing by the max alone left
        # outside [0, 1]. For images whose minimum is 0 this equals the
        # previous max-only scaling.
        lo = float(image_array.min())
        hi = float(image_array.max())
        if hi > lo:
            image_array = (image_array - lo) / (hi - lo)
        else:
            # Constant image: nothing to normalise.
            image_array = np.zeros_like(image_array)
        resized_image = resize(image_array, IMG_SIZE, anti_aliasing=True)
        if resized_image.ndim == 2:
            resized_image = np.stack((resized_image,) * 3, axis=-1)
        # float32 halves memory versus resize's float64 output (a full CT
        # series is several hundred slices); training converts to float32 anyway.
        preprocessed_images.append(resized_image.astype(np.float32, copy=False))
    return np.array(preprocessed_images, dtype=np.float32)

def _subsample_series(series_images, max_images):
    """Keep at most ``max_images`` items, evenly spaced over the whole series.

    ``None`` or a non-positive ``max_images`` disables the cap.
    """
    if max_images is None or max_images <= 0 or len(series_images) <= max_images:
        return series_images
    # Evenly spaced indices span the full series instead of only its first
    # slices; with len(series_images) > max_images the rounded indices are distinct.
    picks = np.linspace(0, len(series_images) - 1, num=max_images).round().astype(int)
    return [series_images[k] for k in picks]


def _load_label_table(clinical_data_path, target_columns):
    """Read the clinical sheet and keep only rows usable as binary labels."""
    clinical_data = read_clinical_data(clinical_data_path)
    clinical_data.columns = clinical_data.columns.str.lower()
    # Guarded: indexing a missing 'comment' column used to raise KeyError.
    if 'comment' in clinical_data.columns:
        clinical_data = clinical_data[clinical_data['comment'] != 'Study without report']
    # Drop rows whose target values are anything other than 0/1 (e.g. blanks).
    for column in target_columns:
        if column in clinical_data.columns:
            clinical_data = clinical_data[clinical_data[column].isin([0, 1])]
        else:
            print(f"Warning: Column '{column}' not found in clinical data.")
    return clinical_data


def process_batch(studies_folder, dataset_path, start_idx=START_IDX, batch_size=BATCH_SIZE, max_images_per_series=MAX_IMAGES_PER_SERIES):
    """Load, preprocess and label one window of studies from a folder.

    Args:
        studies_folder: Folder name; data is read from
            ``dataset_path/studies_folder/studies_folder``, which must contain
            exactly one .xlsx file. '200_300_studies' is skipped (hard-coded
            exclusion; the reason is not visible here).
        dataset_path: Dataset root directory.
        start_idx: Position, in the sorted list of study directories, of the
            first study to load.
        batch_size: Maximum number of study directories to load.
        max_images_per_series: Cap on slices kept per series (evenly spaced);
            ``None`` or <= 0 keeps every slice.

    Returns:
        (patient_data, all_clinical_data, next_start_idx):
        - patient_data: study UID -> {'images': {series UID: ndarray},
          'labels': {target column: value}}; only studies that have both
          clinical rows and at least one decoded series are included.
        - all_clinical_data: the clinical rows of those studies (once each).
        - next_start_idx: start_idx plus the number of study directories
          consumed; equals start_idx when nothing is left to read.
    """
    target_columns = ['epidural hemorrhage', 'subarachnoid hemorrhage', 'subdural hemorrhage',
                      'intracerebral hemorrhage', 'multiple hemorrhages', 'skull fracture']
    patient_data = {}
    clinical_frames = []
    consumed = 0

    if studies_folder != '200_300_studies':
        studies_folder_path = os.path.join(dataset_path, studies_folder, studies_folder)
        print(f"Processing folder {studies_folder}")
        xlsx_files = [f for f in os.listdir(studies_folder_path) if f.endswith('.xlsx')]

        if len(xlsx_files) == 1:
            clinical_data = _load_label_table(
                os.path.join(studies_folder_path, xlsx_files[0]), target_columns)

            # Sorted so successive calls page through one stable ordering.
            study_uids_in_folder = sorted(
                uid for uid in os.listdir(studies_folder_path)
                if os.path.isdir(os.path.join(studies_folder_path, uid))
            )
            batch_uids = study_uids_in_folder[start_idx:start_idx + batch_size]
            # Advance by directories consumed, not by studies that produced
            # images, so skipped studies are not re-read by the next call.
            consumed = len(batch_uids)

            for study_uid in batch_uids:
                study_clinical_data = clinical_data[clinical_data['study_uid'] == study_uid]
                if study_clinical_data.empty:
                    # Unlabelled studies cannot be trained on; skip decoding them.
                    print(f"Warning: No clinical data found for study UID: {study_uid}")
                    continue

                series_dict = {}
                study_path = os.path.join(studies_folder_path, study_uid)
                for series_uid, series_images in read_dicom_files(study_path).items():
                    series_images = _subsample_series(series_images, max_images_per_series)
                    processed_images = preprocess_images(series_images)
                    if len(processed_images) > 0:
                        series_dict[series_uid] = processed_images
                        print(f"Processed {len(series_images)} DICOM images for study UID: {study_uid}, series UID: {series_uid}")
                    else:
                        print(f"No images processed for study UID: {study_uid}, series UID: {series_uid}")

                if series_dict:
                    patient_data[study_uid] = {
                        'images': series_dict,
                        'labels': study_clinical_data.iloc[0][target_columns].to_dict(),
                    }
                    # Collected once per study (previously once per series,
                    # duplicating rows) and concatenated once below.
                    clinical_frames.append(study_clinical_data)
        else:
            print(f"Error: Expected exactly one XLSX file in {studies_folder_path}")

    all_clinical_data = pd.concat(clinical_frames, ignore_index=True) if clinical_frames else pd.DataFrame()
    print("Number of processed studies:", len(patient_data))
    print("Number of rows in clinical data:", len(all_clinical_data))

    return patient_data, all_clinical_data, start_idx + consumed

from torchvision.models.resnet import ResNet50_Weights

def _build_model():
    """Return an ImageNet-pretrained ResNet-50 whose head emits 6 logits."""
    model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, 6)  # one logit per target label
    return model


def create_or_load_model(checkpoint_path):
    """Build the 6-label ResNet-50 and restore weights from a checkpoint if possible.

    Accepts either a training checkpoint dict with a 'model_state_dict' key
    or a bare state dict. When the file is missing or cannot be applied, a
    freshly built (untouched) model is returned and the reason is printed.

    Args:
        checkpoint_path: Path of a file written by ``torch.save``.

    Returns:
        torch.nn.Module on the CPU.
    """
    import pickle  # only for the exception type torch.load raises on unreadable files

    model = _build_model()
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found. Creating new model")
        return model

    print(f"Attempting to load checkpoint from {checkpoint_path}")
    try:
        # map_location='cpu': checkpoints written during GPU training would
        # otherwise fail to load on a CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        # Training checkpoints wrap the weights; a bare state dict is accepted
        # too instead of being discarded on KeyError.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        print("Checkpoint loaded successfully")
    except (KeyError, RuntimeError, TypeError, EOFError, pickle.UnpicklingError) as e:
        print(f"Error loading checkpoint: {e}")
        print("Creating new model instead")
        # load_state_dict copies every matching tensor before raising, so the
        # old model may be partially overwritten; rebuild a clean one.
        model = _build_model()
    return model

def custom_collate(batch):
    """Collate ``(images, label, patient_id, series_uid)`` samples into a batch.

    Series have different slice counts, so each image tensor is zero-padded
    along its first (slice) dimension to the longest series in the batch.

    Args:
        batch: Sequence of 4-tuples as yielded by the dataset.

    Returns:
        images: tensor (B, max_slices, *per-slice shape), zero padded, with
            the inputs' dtype and device.
        labels: the label tensors stacked along a new first dimension.
        patient_ids, series_uids: tuples in batch order.

    Note: padded slices are all zeros and are not flagged; consumers must
    recognise them themselves.
    """
    images, labels, patient_ids, series_uids = zip(*batch)
    # pad_sequence creates the padding with the inputs' own dtype/device; the
    # previous hand-rolled torch.zeros(...) was always CPU float32.
    images = torch.nn.utils.rnn.pad_sequence(list(images), batch_first=True)
    labels = torch.stack(labels)
    return images, labels, patient_ids, series_uids

from torch.cuda.amp import autocast, GradScaler

def train_model_in_batches(model, studies_folders, dataset_path, batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, accumulation_steps=4):
    """Train ``model`` over every study folder, loading studies window by window.

    For each epoch and folder, windows of studies are loaded with
    ``process_batch``. Every series is one sample; its slices are fed to the
    model ``batch_size`` slices per series at a time, each slice carrying
    its study's labels (study-level labels applied per slice). Gradients are
    accumulated over ``accumulation_steps`` sub-batches per optimizer step,
    and a checkpoint is written to CHECKPOINT_PATH after every window.

    Args:
        model: Network mapping (N, 3, H, W) images to (N, num_labels) logits.
        studies_folders: Folder names passed to ``process_batch``.
        dataset_path: Dataset root passed to ``process_batch``.
        batch_size: Slices per series in each forward pass.
        num_epochs: Number of passes over all folders.
        accumulation_steps: Sub-batches per optimizer step.

    Returns:
        The trained model, on the training device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    model = model.to(device)
    model.train()
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # torch.amp API: the torch.cuda.amp variants are deprecated (FutureWarning).
    # Mixed precision is enabled only on CUDA.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        for studies_folder in studies_folders:
            start_idx = 0
            while True:
                patient_data, clinical_data, next_start_idx = process_batch(studies_folder, dataset_path, start_idx, BATCH_SIZE, MAX_IMAGES_PER_SERIES)

                if next_start_idx <= start_idx:
                    break  # no studies consumed: folder exhausted or unreadable
                start_idx = next_start_idx
                if not patient_data:
                    # A window without labelled studies is skipped rather than
                    # ending the folder early.
                    continue

                dataset = PatientDataset(patient_data, list(patient_data.keys()))
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=CHUNK_SIZE, shuffle=True, collate_fn=custom_collate)

                for i, (images, labels, patient_ids, series_uids) in enumerate(dataloader):
                    print(f"Batch {i+1} - Images shape: {images.shape}, Labels shape: {labels.shape}")
                    optimizer.zero_grad()
                    pending = 0  # backward passes since the last optimizer step

                    for j in range(0, images.size(1), batch_size):
                        chunk = images[:, j:j + batch_size]
                        slices = chunk.size(1)
                        batch = chunk.reshape(-1, 3, IMG_SIZE[0], IMG_SIZE[1])
                        # Row r of `batch` is slice r % slices of series r // slices,
                        # so repeat each label once per slice in THIS chunk
                        # (repeating by batch_size misaligned a short last chunk).
                        batch_labels = labels.repeat_interleave(slices, dim=0)
                        # Drop all-zero rows: custom_collate's padding would
                        # otherwise be trained as if it carried the study's labels.
                        keep = (batch.flatten(1) != 0).any(dim=1)
                        if not bool(keep.any()):
                            continue
                        # Move only this sub-batch to the device, not the whole padded batch.
                        batch = batch[keep].to(device, non_blocking=True)
                        batch_labels = batch_labels[keep].to(device, non_blocking=True)

                        print(f"  Sub-batch {j//batch_size + 1} - Batch shape: {batch.shape}, Batch labels shape: {batch_labels.shape}")

                        with torch.autocast(device_type=device.type, enabled=use_amp):
                            outputs = model(batch)
                            print(f"  Outputs shape: {outputs.shape}")
                            # Scaled so accumulated gradients average over sub-batches.
                            loss = criterion(outputs, batch_labels) / accumulation_steps

                        scaler.scale(loss).backward()
                        pending += 1
                        if pending == accumulation_steps:
                            scaler.step(optimizer)
                            scaler.update()
                            optimizer.zero_grad()
                            pending = 0

                    # Apply the remainder instead of letting the next zero_grad() discard it.
                    if pending:
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()

                    print(f"Batch {i+1} processed")

                # Save checkpoint after each window of studies
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, CHECKPOINT_PATH)
                print(f"Checkpoint saved to {CHECKPOINT_PATH}")

                # Release this window's decoded images before loading the next.
                del dataset, dataloader, patient_data, clinical_data
                gc.collect()

    return model

class PatientDataset(Dataset):
    """Map-style dataset with one sample per (patient, series) pair.

    Each sample is ``(series_tensor, label_tensor, patient_id, series_uid)``:
    - series_tensor: float32, the stored (N, H, W, C) image stack permuted to
      (N, C, H, W). When the stored array is already float32, the tensor
      shares its memory.
    - label_tensor: float32 label values ordered by sorted label name.

    Args:
        patient_data: dict patient_id -> {'images': {series_uid: array},
            'labels': {label name: value}}.
        patient_ids: Keys of ``patient_data`` to include.
    """

    def __init__(self, patient_data, patient_ids):
        self.patient_data = patient_data
        self.patient_ids = patient_ids
        self.series_data = []

        for patient_id in self.patient_ids:
            patient = self.patient_data[patient_id]
            labels = patient['labels']
            # Alphabetical key order gives every sample the same label layout.
            label_tensor = torch.tensor([labels[key] for key in sorted(labels.keys())]).float()

            for series_uid, series_images in patient['images'].items():
                # Convert to float32 once and wrap without a further copy. The
                # previous torch.tensor(...).float() copied the (possibly
                # GB-sized) array twice for float64 input. permute is a view.
                array = np.asarray(series_images, dtype=np.float32)
                series_tensor = torch.from_numpy(array).permute(0, 3, 1, 2)
                self.series_data.append((series_tensor, label_tensor, patient_id, series_uid))

    def __len__(self):
        return len(self.series_data)

    def __getitem__(self, idx):
        return self.series_data[idx]

def save_final_model(model, path=CHECKPOINT_PATH):
    """Save ``model``'s weights to ``path`` in the training-checkpoint layout.

    The state dict is stored under 'model_state_dict', the key that
    create_or_load_model loads from and that the per-window training
    checkpoints use. A bare state dict at the same default path would not
    match that layout. This overwrites any checkpoint already at ``path``,
    including its optimizer state.
    """
    torch.save({'model_state_dict': model.state_dict()}, path)
    print(f"Final model saved to {path}")

# Dataset root as mounted in the Kaggle environment.
dataset_path = (
    '/kaggle/input/mosmeddata-ct-hemorrhage-type-viii/'
    'MosMedData-CT-HEMORRHAGE-type VIII/'
)

# Study folders to train on. To train on every "*_studies" folder instead use:
#     [f for f in os.listdir(dataset_path) if f.endswith('_studies')]
studies_folders = ['400_500_studies']
print(studies_folders)

# Main execution: build (or restore) the model, train it, then persist it.
model = create_or_load_model(CHECKPOINT_PATH)
model = train_model_in_batches(model, studies_folders, dataset_path)
save_final_model(model)

['400_500_studies']
No checkpoint found. Creating new model
Epoch 1/5
Processing folder 400_500_studies


  scaler = GradScaler()


Processed 851 DICOM images for study UID: 1.2.643.5.1.13.13.12.2.77.8252.03050206031205090212101415130402, series UID: 1.2.643.5.1.13.13.12.2.77.8252.09030205031314101014040300050508
Processed 551 DICOM images for study UID: 1.2.643.5.1.13.13.12.2.77.8252.12100603091102140112150804140706, series UID: 1.2.643.5.1.13.13.12.2.77.8252.05060713020912040111131210000303
Processed 769 DICOM images for study UID: 1.2.643.5.1.13.13.12.2.77.8252.06010500031303140403080114141012, series UID: 1.2.643.5.1.13.13.12.2.77.8252.09061213110611041510050311120313
Processed 251 DICOM images for study UID: 1.2.643.5.1.13.13.12.2.77.8252.14001507020807131402151406001108, series UID: 1.2.643.5.1.13.13.12.2.77.8252.00110701141401140814120108110706
Processed 83 DICOM images for study UID: 1.2.643.5.1.13.13.12.2.77.8252.08010811121300080404100303090501, series UID: 1.2.643.5.1.13.13.12.2.77.8252.02140815040601101009000405050805
Processed 329 DICOM images for study UID: 1.2.643.5.1.13.13.12.2.77.8252.0801081112130

  with autocast():


  Outputs shape: torch.Size([64, 6])
  Sub-batch 2 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Su

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Outputs shape: torch.Size([40, 6])
  Sub-batch 4 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 5 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 6 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 7 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 8 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 9 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 10 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  S

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Outputs shape: torch.Size([40, 6])
  Sub-batch 4 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 5 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 6 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 7 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 8 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 9 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 10 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  S

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Outputs shape: torch.Size([40, 6])
  Sub-batch 4 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 5 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 6 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 7 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 8 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 9 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 10 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  S

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Outputs shape: torch.Size([40, 6])
  Sub-batch 4 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 5 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 6 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 7 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 8 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 9 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  Sub-batch 10 - Batch shape: torch.Size([40, 3, 224, 224]), Batch labels shape: torch.Size([40, 6])
  Outputs shape: torch.Size([40, 6])
  S

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 4 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 5 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 6 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 7 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 8 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 9 - Batch shape: torch.Size([64, 3, 224, 224]), Batch labels shape: torch.Size([64, 6])
  Outputs shape: torch.Size([64, 6])
  Sub-batch 10 - Batch shape: torch.Size(

  with autocast():


  Sub-batch 3 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 4 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 5 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 6 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 7 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 8 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 9 - Batch shape: torch.Size([56, 3, 224, 224]), Batch labels shape: torch.Size([56, 6])
  Outputs shape: torch.Size([56, 6])
  Sub-batch 10 - Batch shape: torch.Size(