In [1]:
from tqdm import tqdm
import gc
import time
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import sys
import warnings
warnings.filterwarnings('ignore')
sys.path.append("../utils/")
from utils.MLUtils import TrainerBase, cv_trainer, getSubsetIdx
from MyModels import GrindingPredictor
from MyDataset import MemoryDataset, get_dataset, get_collate_fn
from MyDataset import sampling_rate_ae,sampling_rate_vib, project_dir, project_name, data_dir, dataDir_ae, dataDir_vib, alphabet, allowed_input_types,logical_threads,physical_threads

class Trainer(TrainerBase):
    """TrainerBase specialisation that passes the whole batch dict to the model."""

    def _forward(self, batch):
        """Run one forward pass on a batch.

        Args:
            batch: Mapping of names to values. Tensor values are moved to
                ``self.device``; every other value is passed through
                unchanged. Must contain a ``'label'`` entry.

        Returns:
            Tuple ``(outputs, labels)`` where ``outputs`` is whatever
            ``self.model`` returns for the device-resident batch and
            ``labels`` is the ``'label'`` entry on ``self.device``.

        Raises:
            KeyError: If ``batch`` has no ``'label'`` entry.
        """
        inputs = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }
        # The label was already moved together with the rest of the batch;
        # reuse that tensor instead of issuing a second host-to-device copy.
        labels = inputs['label']

        outputs = self.model(inputs)

        return outputs, labels

class TrainingConfig:
    """Singleton bundle of the settings used by this notebook.

    Every call to ``TrainingConfig()`` returns the same object; the
    attributes are populated once, when the first instance is created.
    """

    # Cached shared instance (None until the first construction).
    _instance = None

    def __new__(cls):
        """Return the shared instance, building and initialising it on first use."""
        existing = cls._instance
        if existing is not None:
            return existing

        created = super().__new__(cls)
        # Publish the instance before initialising it, as the cache is the
        # single source of truth for "already built".
        cls._instance = created
        created._initialize()
        return created

    def _initialize(self):
        """Assign the default value of every configuration attribute."""
        # NOTE(review): both `num_epochs` and `epochs` are defined here; which
        # one the training code reads is not visible from this cell.
        defaults = {
            'model': GrindingPredictor,
            'input_type': 'all',
            'dataset_mode': 'classical',
            'batch_size': 4,
            'num_workers': 4,
            'learning_rate': 0.001,
            'num_epochs': 10,
            'gpu': 'cpu',  # device string handed to torch.device when CUDA is available
            'repeat': 1,
            'epochs': 2,
            'folds': 10,
            'model_name': 'SmartGrinding',
            'threads': 8,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)

# Fetch the shared configuration object.
args = TrainingConfig()
# Use the configured device string only when CUDA is usable; otherwise
# fall back to the CPU.
device = torch.device("cpu" if not torch.cuda.is_available() else args.gpu)

In [2]:
# Name under which this model/run is identified.
model_name = args.model_name

# Parallel lists forwarded verbatim to get_dataset.
# NOTE(review): their exact meaning is defined inside get_dataset, which is
# not visible here — presumably worker counts per progress fraction; confirm.
cpus = [32, 16, 2, 1]
percentage = [0.6, 0.8, 0.90, 1]

# Build the dataset and the matching collate function for the configured
# input type.
dataset = get_dataset(
    input_type=args.input_type,
    dataset_mode=args.dataset_mode,
    cpus=cpus,
    percentage=percentage,
)
collate_fn = get_collate_fn(input_type=args.input_type)

Required components: {'all'}


In [3]:
# Small, deterministic loader used only to pull a sample batch: no
# shuffling, and batches are loaded in the main process (num_workers=0).
loader = DataLoader(
    dataset,
    batch_size=3,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

# Instantiate the predictor for the configured input type.
model = GrindingPredictor(input_type=args.input_type, interp=False)
# Trailing expression: moves the model to `device` and displays its repr.
model.to(device)

GrindingPredictor(
  (ae_spec_processor): SpectrogramProcessor(
    (conv): Sequential(
      (0): Conv2d(2, 8, kernel_size=(3, 3), stride=(2, 2))
      (1): ReLU()
      (2): AdaptiveAvgPool2d(output_size=(4, 4))
      (3): Flatten(start_dim=1, end_dim=-1)
      (4): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (vib_spec_processor): SpectrogramProcessor(
    (conv): Sequential(
      (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(2, 2))
      (1): ReLU()
      (2): AdaptiveAvgPool2d(output_size=(4, 4))
      (3): Flatten(start_dim=1, end_dim=-1)
      (4): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (ae_interpreter): FeatureInterpreter(
    (attention): Sequential(
      (0): Linear(in_features=36, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=36, bias=True)
      (3): Softmax(dim=-1)
    )
    (feature_processor): Linear(in_features=36, out_features=64, bias=True)
  )
  (vib_interpreter): FeatureInte

In [4]:
# Pull the first batch from the loader.
batch = next(iter(loader))
# Move every tensor in the batch to the target device; non-tensor entries
# are kept as they are.
inputs = {
    key: value.to(device) if isinstance(value, torch.Tensor) else value
    for key, value in batch.items()
}
# The label tensor was moved with the rest of the batch above; reuse it
# rather than copying it from the host a second time.
labels = inputs['label']

In [5]:
# Smoke test: build a predictor for every allowed input type and check the
# output shape it produces for the same sample batch.
for input_type in allowed_input_types:
    print(f"Testing input type: {input_type}")
    model = GrindingPredictor(input_type=input_type, interp=False)
    model.to(device)
    # Call the module itself rather than .forward() so that any registered
    # forward hooks are executed as part of the call.
    outs = model(inputs)
    print(f"Output shape for {input_type}: {outs.shape}")

Testing input type: ae_spec
Output shape for ae_spec: torch.Size([3, 1])
Testing input type: vib_spec
Output shape for vib_spec: torch.Size([3, 1])
Testing input type: ae_features
Output shape for ae_features: torch.Size([3, 1])
Testing input type: vib_features
Output shape for vib_features: torch.Size([3, 1])
Testing input type: ae_spec+ae_features
Output shape for ae_spec+ae_features: torch.Size([3, 1])
Testing input type: vib_spec+vib_features
Output shape for vib_spec+vib_features: torch.Size([3, 1])
Testing input type: ae_spec+ae_features+vib_spec+vib_features
Output shape for ae_spec+ae_features+vib_spec+vib_features: torch.Size([3, 1])
Testing input type: ae_features+pp
Output shape for ae_features+pp: torch.Size([3, 1])
Testing input type: vib_features+pp
Output shape for vib_features+pp: torch.Size([3, 1])
Testing input type: pp
Output shape for pp: torch.Size([3, 1])
Testing input type: all
Output shape for all: torch.Size([3, 1])
