In [6]:
from tqdm import tqdm
import gc
import time
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import sys
import warnings
warnings.filterwarnings('ignore')
sys.path.append("../utils/")
from utils.MLUtils import TrainerBase, cv_trainer, getSubsetIdx
from MyModels import GrindingPredictor
from MyDataset import MemoryDataset, get_dataset, get_collate_fn
from MyDataset import sampling_rate_ae,sampling_rate_vib, project_dir, project_name, data_dir, dataDir_ae, dataDir_vib, alphabet, allowed_input_types,logical_threads,physical_threads

class Trainer(TrainerBase):
    def _forward(self, batch):
        # Example implementation:
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        labels = batch['label'].to(self.device)
        
        outputs = self.model(inputs)
        
        return outputs, labels

class TrainingConfig:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        self.model = GrindingPredictor
        self.input_type = 'all'
        self.dataset_mode = 'ram'
        self.batch_size = 4
        self.num_workers = 4
        self.learning_rate = 0.001
        self.num_epochs = 10
        self.gpu = 'cuda:0'
        self.repeat = 1
        self.epochs = 2
        self.folds = 10
        self.model_name = 'SmartGrinding'
        self.threads = 8

args = TrainingConfig()
device = torch.device(args.gpu if torch.cuda.is_available() else "cpu")

In [2]:
dataset = get_dataset(input_type=args.input_type, dataset_mode=args.dataset_mode, cpus = [logical_threads, physical_threads, 2, 1], percentage=[0.6, 0.8, 0.90, 1])
collate_fn = get_collate_fn(input_type=args.input_type)
model_name = args.model_name

Loading all spectrograms
Generating segments for lenDataset = 320
------------------------------
Cpus: 64 | Segment 0-60%: Length = 192 (Indices 0-191)
Cpus: 32 | Segment 60-80%: Length = 64 (Indices 192-255)
Cpus: 2 | Segment 80-90%: Length = 32 (Indices 256-287)
Cpus: 1 | Segment 90-100%: Length = 32 (Indices 288-319)
------------------------------

Total number of indices across all segments: 320 (Should be 320)


Loading 0-60% data for 64 threads: 192it [05:09,  1.61s/it]


Loading threads (64) with remaining of data (192/320)


Loading 60-80% data for 32 threads: 64it [02:12,  2.07s/it]


Loading threads (32) with remaining of data (256/320)


Loading 80-90% data for 2 threads: 32it [02:17,  4.30s/it]


Loading threads (2) with remaining of data (288/320)


Loading 90-100% data for single thread: 100%|██████████| 32/32 [03:56<00:00,  7.39s/it]

Length of full_data: 320
Estimated size of full_data: 0.00 GB
Required components: {'all'}





In [19]:
loader = DataLoader(dataset, batch_size=5, shuffle=False, num_workers=0, collate_fn=collate_fn)
model = GrindingPredictor(input_type=args.input_type, interp=False)
model.to(device)

GrindingPredictor(
  (ae_spec_processor): SpectrogramProcessor(
    (conv): Sequential(
      (0): Conv2d(2, 8, kernel_size=(3, 3), stride=(2, 2))
      (1): ReLU()
      (2): AdaptiveAvgPool2d(output_size=(4, 4))
      (3): Flatten(start_dim=1, end_dim=-1)
      (4): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (vib_spec_processor): SpectrogramProcessor(
    (conv): Sequential(
      (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(2, 2))
      (1): ReLU()
      (2): AdaptiveAvgPool2d(output_size=(4, 4))
      (3): Flatten(start_dim=1, end_dim=-1)
      (4): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (ae_interpreter): FeatureInterpreter(
    (attention): Sequential(
      (0): Linear(in_features=36, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=36, bias=True)
      (3): Softmax(dim=-1)
    )
    (feature_processor): Linear(in_features=36, out_features=64, bias=True)
  )
  (vib_interpreter): FeatureInte

In [10]:
batch = next(iter(loader))
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
labels = batch['label'].to(device)