In [1]:
from tqdm import tqdm
import gc
import time
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import sys
import warnings
warnings.filterwarnings('ignore')
sys.path.append("../utils/")
from utils.MLUtils import TrainerBase, cv_trainer, getSubsetIdx
from MyModels import GrindingPredictor
from MyDataset import MemoryDataset, get_dataset, get_collate_fn
from MyDataset import sampling_rate_ae,sampling_rate_vib, project_dir, project_name, data_dir, dataDir_ae, dataDir_vib, alphabet, allowed_input_types,logical_threads,physical_threads

class Trainer(TrainerBase):
    """Trainer whose forward step moves a dict batch to the device and runs the model."""

    def _forward(self, batch):
        """Run the model on one collated batch.

        Args:
            batch: mapping produced by the collate function. Tensor values are moved
                to ``self.device``; non-tensor values are passed through unchanged.
                Must contain a ``'label'`` tensor.

        Returns:
            ``(outputs, labels)``: the model output and the labels on ``self.device``.
        """
        inputs = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }
        # The label tensor was already transferred with the rest of the batch above;
        # reuse it instead of copying it to the device a second time.
        labels = inputs['label']

        outputs = self.model(inputs)

        return outputs, labels

class TrainingConfig:
    """Process-wide singleton holding the default training hyper-parameters.

    ``TrainingConfig()`` always returns the same object. Its attributes are filled
    in once, by ``_initialize``, the first time the class is instantiated.
    """

    _instance = None

    def __new__(cls):
        # Reuse the shared instance once it exists.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        cls._instance._initialize()
        return cls._instance

    def _initialize(self):
        """Assign the default settings as instance attributes (called once by ``__new__``)."""
        defaults = {
            # Model class, looked up when the singleton is first built.
            'model': GrindingPredictor,
            'input_type': 'all',
            'dataset_mode': 'ram',
            'batch_size': 4,
            'num_workers': 4,
            'learning_rate': 0.001,
            # NOTE(review): both `num_epochs` and `epochs` are set, with different
            # values; which one the training code reads is not visible here.
            'num_epochs': 10,
            'gpu': 'cuda:0',
            'repeat': 1,
            'epochs': 2,
            'folds': 10,
            'model_name': 'SmartGrinding',
            'threads': 8,
        }
        for name, value in defaults.items():
            setattr(self, name, value)

# Shared configuration singleton and the compute device
# (falls back to the CPU when CUDA is unavailable).
args = TrainingConfig()
device = torch.device(args.gpu) if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# The dataset is loaded in consecutive slices: slice i ends at percentage[i] of the
# data and is loaded with cpus[i] threads (see the "Segment ..." log below).
cpus = [32, 16, 2, 1]
percentage = [0.6, 0.8, 0.90, 1]

dataset = get_dataset(
    input_type=args.input_type,
    dataset_mode=args.dataset_mode,
    cpus=cpus,
    percentage=percentage,
)
collate_fn = get_collate_fn(input_type=args.input_type)
model_name = args.model_name

Generating segments for lenDataset = 320
------------------------------
Cpus: 32 | Segment 0-60%: Length = 192 (Indices 0-191)
Cpus: 16 | Segment 60-80%: Length = 64 (Indices 192-255)
Cpus: 2 | Segment 80-90%: Length = 32 (Indices 256-287)
Cpus: 1 | Segment 90-100%: Length = 32 (Indices 288-319)
------------------------------

Total number of indices across all segments: 320 (Should be 320)


Loading 0-60% data for 32 threads: 192it [03:08,  1.02it/s]


Loading threads (32) with remaining of data (192/320)


Loading 60-80% data for 16 threads: 64it [01:06,  1.05s/it]


Loading threads (16) with remaining of data (256/320)


Loading 80-90% data for 2 threads: 32it [01:32,  2.88s/it]


Loading threads (2) with remaining of data (288/320)


Loading 90-100% data for single thread: 100%|██████████| 32/32 [01:54<00:00,  3.58s/it]

Length of full_data: 320
Estimated size of full_data: 0.00 GB
Required components: {'all'}





In [3]:
# Small, unshuffled, single-process loader; in this notebook it is only used to pull
# one sample batch for the forward-pass checks.
# NOTE(review): batch_size=5 differs from args.batch_size (4) - confirm this is intended.
loader = DataLoader(
    dataset,
    batch_size=5,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)
model = GrindingPredictor(input_type=args.input_type, interp=False).to(device)
model

GrindingPredictor(
  (ae_spec_processor): SpectrogramProcessor(
    (conv): Sequential(
      (0): Conv2d(2, 8, kernel_size=(3, 3), stride=(2, 2))
      (1): ReLU()
      (2): AdaptiveAvgPool2d(output_size=(4, 4))
      (3): Flatten(start_dim=1, end_dim=-1)
      (4): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (vib_spec_processor): SpectrogramProcessor(
    (conv): Sequential(
      (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(2, 2))
      (1): ReLU()
      (2): AdaptiveAvgPool2d(output_size=(4, 4))
      (3): Flatten(start_dim=1, end_dim=-1)
      (4): Linear(in_features=128, out_features=32, bias=True)
    )
  )
  (ae_interpreter): FeatureInterpreter(
    (attention): Sequential(
      (0): Linear(in_features=36, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=36, bias=True)
      (3): Softmax(dim=-1)
    )
    (feature_processor): Linear(in_features=36, out_features=64, bias=True)
  )
  (vib_interpreter): FeatureInte

In [4]:
# Take a single sample batch and move its tensors to `device` for the smoke tests.
batch = next(iter(loader))
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
# The label tensor was already transferred with the rest of the batch; reuse it
# rather than copying it to the device again.
labels = inputs['label']

In [9]:
from MyModels import SpectrogramProcessor, FeatureInterpreter

class GrindingPredictor(nn.Module):
    """Multi-branch regressor over AE, vibration and process-parameter inputs.

    ``input_type`` selects which branches :meth:`forward` evaluates and what is
    concatenated into the final regressor:

    * ``'pp'`` - physics-encoder output only
    * ``'ae_spec'`` / ``'vib_spec'`` - one sensor's interpreter output
    * ``'ae_spec+ae_features'`` / ``'vib_spec+vib_features'`` - that interpreter
      output plus the mean over dim 1 of the sensor's time features
    * ``'ae_spec+ae_features+vib_spec+vib_features'`` - both of the above
    * anything else (intended: ``'all'``) - AE + vib + physics outputs

    All sub-networks are constructed regardless of the selected mode.

    NOTE(review): this cell shadows the ``GrindingPredictor`` imported from
    ``MyModels`` in the first cell. ``TrainingConfig`` captured the imported class
    when ``args`` was created, so ``args.model`` does not refer to this definition.

    Args:
        interp: if True, :meth:`forward` also returns the interpreters' attention.
        input_type: branch-selection mode (see above).
    """

    def __init__(self, interp=False, input_type="all"):
        super().__init__()
        self.input_type = input_type
        self.interp = interp

        # Keep this creation order: it fixes parameter registration (state_dict keys)
        # and the RNG sequence consumed by PyTorch's default initialisation.
        # AE pathway: 2-channel spectrogram -> 32-d embedding (+ 4 time features).
        self.ae_spec_processor = SpectrogramProcessor(2, out_dim=32)
        # Vib pathway: 3-channel spectrogram -> 32-d embedding (+ 4 time features).
        self.vib_spec_processor = SpectrogramProcessor(3, out_dim=32)

        self.ae_interpreter = FeatureInterpreter(spec_feat_dim=32, time_feat_dim=4)
        self.vib_interpreter = FeatureInterpreter(spec_feat_dim=32, time_feat_dim=4)

        # Process-parameter ("physics") encoder: 3 inputs -> 64-d, layer-normalised.
        self.physics_encoder = nn.Sequential(
            nn.Linear(3, 64), nn.ReLU(), nn.LayerNorm(64)
        )

        # Fusion head; its input width depends on which features input_type concatenates.
        self.regressor_input_dim = self._calculate_regressor_input_dim()
        self.regressor = nn.Sequential(
            nn.Linear(self.regressor_input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, batch):
        """Run the branches selected by ``input_type`` and regress one value per sample.

        Args:
            batch: dict of tensors. Depending on ``input_type`` it must provide
                ``"spec_ae"``/``"features_ae"``, ``"spec_vib"``/``"features_vib"``
                and/or ``"features_pp"``.

        Returns:
            The regressor output (last layer is ``Linear(128, 1)``). If ``interp``
            is True, returns ``(prediction, {"ae": ae_attn, "vib": vib_attn})`` with
            ``None`` for branches that did not run.
        """
        mode = self.input_type
        # Branches are enabled by substring match on the mode string.
        run_ae = 'ae' in mode or 'all' in mode
        run_vib = 'vib' in mode or 'all' in mode
        run_pp = 'all' in mode or 'pp' in mode
        outputs = {}

        if run_ae:
            ae_spec = self.ae_spec_processor(batch["spec_ae"])
            ae_time = batch["features_ae"]  # presumably [batch, seq_len, 4] - confirm
            # NOTE(review): the interpreter also consumes the time features, so the
            # 'ae_spec' mode is not spectrogram-only; confirm this is intended.
            ae_out, ae_attn = self.ae_interpreter(ae_spec, ae_time)
            outputs['ae_out'] = ae_out
            outputs['ae_attn'] = ae_attn

        if run_vib:
            vib_spec = self.vib_spec_processor(batch["spec_vib"])
            vib_time = batch["features_vib"]  # presumably [batch, seq_len, 4] - confirm
            vib_out, vib_attn = self.vib_interpreter(vib_spec, vib_time)
            outputs['vib_out'] = vib_out
            outputs['vib_attn'] = vib_attn

        if run_pp:
            outputs['physics'] = self.physics_encoder(batch["features_pp"])

        # Assemble the regressor input; widths must match _calculate_regressor_input_dim.
        if mode == 'ae_spec':
            combined = outputs['ae_out']
        elif mode == 'vib_spec':
            combined = outputs['vib_out']
        elif mode == 'ae_spec+ae_features':
            combined = torch.cat([outputs['ae_out'], ae_time.mean(dim=1)], dim=1)
        elif mode == 'vib_spec+vib_features':
            combined = torch.cat([outputs['vib_out'], vib_time.mean(dim=1)], dim=1)
        elif mode == 'ae_spec+ae_features+vib_spec+vib_features':
            combined = torch.cat(
                [outputs['ae_out'], ae_time.mean(dim=1), outputs['vib_out'], vib_time.mean(dim=1)],
                dim=1,
            )
        elif mode == 'pp':
            combined = outputs['physics']
        else:  # 'all'
            combined = torch.cat([outputs['ae_out'], outputs['vib_out'], outputs['physics']], dim=1)

        prediction = self.regressor(combined)
        if self.interp:
            return prediction, {"ae": outputs.get('ae_attn'), "vib": outputs.get('vib_attn')}
        return prediction

    def _calculate_regressor_input_dim(self):
        """Return the regressor input width for ``self.input_type``.

        Must agree with the concatenation in :meth:`forward`. Modes not listed fall
        back to the ``'all'`` width (AE + vib + physics).
        """
        branch_dim = 64  # width of each interpreter / physics-encoder output
        time_dim = 4     # number of per-sensor time features appended
        input_type = self.input_type
        if input_type in ('pp', 'ae_spec', 'vib_spec'):
            return branch_dim  # a single branch output
        if input_type in ('ae_spec+ae_features', 'vib_spec+vib_features'):
            return branch_dim + time_dim  # one sensor output + its mean time features
        if input_type == 'ae_spec+ae_features+vib_spec+vib_features':
            return 2 * (branch_dim + time_dim)  # both sensors, each with time features
        return 3 * branch_dim  # 'all': AE + vib + physics

    def _init_weights(self, m):
        """Initialise one submodule in place (applied to every module by ``initialize_weights``).

        Linear: Xavier-uniform weights. Conv2d: Kaiming-normal (fan_out, ReLU).
        LayerNorm / BatchNorm2d: weight 1, bias 0. All biases are zeroed.
        Parameters that are ``None`` (bias-free layers, affine-less norms) are skipped
        instead of raising.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # weight/bias are None for LayerNorm(elementwise_affine=False) and
            # BatchNorm2d(affine=False); initialising them would raise.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def initialize_weights(self):
        """Re-initialise every submodule with :meth:`_init_weights`.

        Not called from ``__init__``; invoke it explicitly to replace PyTorch's
        default initialisation.
        """
        self.apply(self._init_weights)


In [12]:
# Smoke test: build a fresh model for every supported input_type and check that a
# forward pass on the sample batch yields one prediction per sample.
# NOTE(review): this rebinds `allowed_input_types` imported from MyDataset in the
# first cell - confirm the two lists agree.
allowed_input_types = [
    'pp',
    'ae_spec',
    'vib_spec',
    'ae_spec+ae_features',
    'vib_spec+vib_features',
    'ae_spec+ae_features+vib_spec+vib_features',
    'all',
]
with torch.no_grad():  # shape check only; no autograd graph needed
    for input_type in allowed_input_types:
        print(f"Testing input type: {input_type}")
        model = GrindingPredictor(input_type=input_type, interp=False)
        model.to(device)
        # Call the module rather than .forward() so registered hooks run.
        outs = model(inputs)
        print(f"Output shape for {input_type}: {outs.shape}")
        assert outs.shape == (len(labels), 1), f"unexpected output shape for {input_type}"

Testing input type: pp
Output shape for pp: torch.Size([5, 1])
Testing input type: ae_spec
Output shape for ae_spec: torch.Size([5, 1])
Testing input type: vib_spec
Output shape for vib_spec: torch.Size([5, 1])
Testing input type: ae_spec+ae_features
Output shape for ae_spec+ae_features: torch.Size([5, 1])
Testing input type: vib_spec+vib_features
Output shape for vib_spec+vib_features: torch.Size([5, 1])
Testing input type: ae_spec+ae_features+vib_spec+vib_features
Output shape for ae_spec+ae_features+vib_spec+vib_features: torch.Size([5, 1])
Testing input type: all
Output shape for all: torch.Size([5, 1])
