In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.preprocessing import LabelEncoder, StandardScaler
import pandas as pd
import numpy as np
import os
from datetime import datetime


In [None]:

# Numeric feature columns. Defined at module level (not inside the loader) because
# prepare_encoders, SafeBlendDataset and the training cell all read `numeric_cols`.
# NOTE(review): 'Van_Der_Waals volumeFraction_non_rotatable_bounds' looks like two column
# names fused together ('Van_Der_Waals volume' + 'Fraction_non_rotatable_bounds');
# confirm against the CSV header.
NUMERIC_COLS = ['mass_fraction', 'log_transformed', 'LogP', 'TPSA',
                'MolWt', 'Van_Der_Waals volumeFraction_non_rotatable_bounds',
                'num_atoms', 'Degree_of_branching']
numeric_cols = NUMERIC_COLS  # lowercase alias used by the later cells


def load_and_validate_data(path, numeric_cols=None):
    """Load the blend CSV and make its numeric feature columns free of NaN/inf.

    Parameters
    ----------
    path : str or path-like
        CSV file readable by ``pandas.read_csv``.
    numeric_cols : list of str, optional
        Numeric columns to validate and median-fill; defaults to ``NUMERIC_COLS``.

    Returns
    -------
    pandas.DataFrame
        The loaded frame, with NaNs in the numeric columns replaced by each column's median.

    Raises
    ------
    KeyError
        If the target column or any numeric column is missing from the file.
    AssertionError
        If the target has NaNs, a numeric column is still NaN after filling
        (i.e. it was entirely NaN), or any numeric value is infinite.
    """
    cols = list(NUMERIC_COLS if numeric_cols is None else numeric_cols)
    df = pd.read_csv(path)

    missing = [c for c in ['oil_property_value', *cols] if c not in df.columns]
    if missing:
        raise KeyError(f"Missing columns in {path}: {missing}")

    assert df['oil_property_value'].isna().sum() == 0, "NaN в целевом признаке"

    for col in cols:
        if df[col].isna().any():
            median_val = df[col].median()
            print(f"Заполнение {col} медианой: {median_val:.2f}")
            df[col] = df[col].fillna(median_val)

    # An all-NaN column has a NaN median, so the fill above leaves it unchanged.
    still_nan = [c for c in cols if df[c].isna().any()]
    assert not still_nan, f"Columns still contain NaN after median fill: {still_nan}"

    assert not np.isinf(df[cols]).any().any(), "Обнаружены бесконечные значения"

    return df


In [None]:

def prepare_encoders(df):
    """Fit the categorical label encoders and the numeric-feature scaler on ``df``.

    Returns
    -------
    (dict, StandardScaler)
        ``encoders`` maps ``'component'`` / ``'type'`` / ``'smiles'`` to a LabelEncoder
        fitted on ``component_name`` / ``component_type_title`` / ``smiles`` respectively;
        the scaler is fitted on ``df[numeric_cols]`` (``numeric_cols`` is a module-level
        name from another cell).
    """
    source_column = {
        'component': 'component_name',
        'type': 'component_type_title',
        'smiles': 'smiles',
    }
    fitted = {key: LabelEncoder().fit(df[column]) for key, column in source_column.items()}

    numeric_scaler = StandardScaler().fit(df[numeric_cols])

    return fitted, numeric_scaler


In [None]:

class SafeBlendDataset(Dataset):
    """Map-style dataset with one sample per ``blend_id``.

    ``dataset[i]`` returns ``(features, target)``:

    * ``features`` -- float32 tensor of shape ``(n_components, 3 + n_numeric)``; columns
      0-2 are the component / type / smiles label-encoder indices (stored as floats, the
      model casts them back with ``.long()``), the rest are the scaled numeric features.
    * ``target`` -- float32 scalar, the first ``oil_property_value`` of the blend.

    If encoding or scaling a blend fails (e.g. a label the encoder has not seen), the
    error is printed and ``None`` is returned so the batch pipeline can skip it.
    """

    def __init__(self, df, encoders, scaler, numeric_cols=None):
        """
        Parameters
        ----------
        df : pandas.DataFrame
            Component-level rows with a ``blend_id`` column.
        encoders : dict
            ``'component'``, ``'type'``, ``'smiles'`` -> fitted encoders with ``transform``.
        scaler : fitted scaler
            Its ``transform`` is applied to the numeric columns of each blend.
        numeric_cols : list of str, optional
            Numeric columns to scale. Defaults to ``scaler.feature_names_in_`` -- the
            columns (in order) the scaler was fitted on, which requires the scaler to
            have been fitted on a DataFrame.
        """
        self.groups = df.groupby('blend_id')
        self.encoders = encoders
        self.scaler = scaler
        self.numeric_cols = list(scaler.feature_names_in_ if numeric_cols is None else numeric_cols)
        # groupby(...).groups is a dict whose keys view is not indexable; materialise once.
        self.blend_ids = list(self.groups.groups)

    def __len__(self):
        """Number of distinct blends (required by DataLoader's default sampler)."""
        return len(self.blend_ids)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for the ``idx``-th blend, or ``None`` on bad data.

        An out-of-range ``idx`` raises ``IndexError`` as usual for a sequence.
        """
        blend_id = self.blend_ids[idx]
        try:
            group = self.groups.get_group(blend_id)

            # Encode / scale the whole group at once instead of row by row.
            categorical = np.stack([
                self.encoders['component'].transform(group['component_name']),
                self.encoders['type'].transform(group['component_type_title']),
                self.encoders['smiles'].transform(group['smiles']),
            ], axis=1)
            numerical = self.scaler.transform(group[self.numeric_cols])

            features = torch.cat([
                torch.as_tensor(categorical, dtype=torch.float32),
                torch.as_tensor(numerical, dtype=torch.float32),
            ], dim=1)
            target = torch.tensor(group['oil_property_value'].iloc[0], dtype=torch.float32)
            return features, target

        except Exception as e:
            print(f"Ошибка в данных blend_id {blend_id}: {str(e)}")
            return None


In [None]:
class SafeTransformer(nn.Module):
    """Transformer regressor mapping the set of components of a blend to one value.

    ``forward`` expects ``x`` of shape ``(batch, n_components, 3 + numeric_dim)``:
    channels 0-2 are component / type / smiles label-encoder indices (stored as floats),
    the remaining channels are numeric features. Each row is embedded, the sequence goes
    through a TransformerEncoder, non-padding rows are mean-pooled, and an MLP head
    returns a ``(batch, 1)`` tensor.

    Padding convention: a row whose component index is negative (e.g. produced by
    ``pad_sequence(..., padding_value=-1)``) is treated as padding and is excluded from
    attention and pooling, unless an explicit ``padding_mask`` is passed.
    """

    def __init__(self, encoders, numeric_dim, nhead=8):
        """
        Parameters
        ----------
        encoders : dict
            ``'component'``, ``'type'``, ``'smiles'`` -> fitted encoders exposing
            ``classes_``; their sizes set the embedding vocabularies.
        numeric_dim : int
            Number of numeric channels following the three index channels.
        nhead : int, default 8
            Attention heads; ``56 + numeric_dim`` must be divisible by it.
        """
        super().__init__()

        self.component_embed = nn.Embedding(len(encoders['component'].classes_), 16)
        self.type_embed = nn.Embedding(len(encoders['type'].classes_), 8)
        self.smiles_embed = nn.Embedding(len(encoders['smiles'].classes_), 32)
        # Xavier-normal init for the embedding tables only; the transformer and head
        # keep PyTorch's default initialisation.
        for table in (self.component_embed, self.type_embed, self.smiles_embed):
            nn.init.xavier_normal_(table.weight)

        d_model = 16 + 8 + 32 + numeric_dim
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=256,
                batch_first=True
            ),
            num_layers=4
        )

        self.regressor = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.Dropout(0.2),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def _embed(self, table, indices):
        """Embed ``indices`` with ``table``, clamping ids into ``[0, num_embeddings - 1]``."""
        return table(torch.clamp(indices, 0, table.num_embeddings - 1))

    def forward(self, x, padding_mask=None):
        """
        Parameters
        ----------
        x : torch.Tensor
            Float tensor of shape ``(batch, n_components, 3 + numeric_dim)``.
        padding_mask : torch.BoolTensor, optional
            ``(batch, n_components)``; True marks padding rows. Inferred from negative
            component indices when omitted.

        Returns
        -------
        torch.Tensor
            Predictions of shape ``(batch, 1)``.

        Raises
        ------
        ValueError
            If ``x`` contains NaN.
        """
        if torch.isnan(x).any():
            raise ValueError("Обнаружен NaN во входных данных")

        categorical = x[:, :, :3].long()
        numerical = x[:, :, 3:]

        if padding_mask is None:
            padding_mask = categorical[:, :, 0] < 0
        # Never mask every row of a sample: attention over zero keys produces NaN.
        padding_mask = padding_mask & ~padding_mask.all(dim=1, keepdim=True)

        h = torch.cat([
            self._embed(self.component_embed, categorical[:, :, 0]),
            self._embed(self.type_embed, categorical[:, :, 1]),
            self._embed(self.smiles_embed, categorical[:, :, 2]),
            numerical,
        ], dim=-1)

        h = self.transformer(h, src_key_padding_mask=padding_mask)

        # Mean over real rows only, so padding does not dilute the blend representation.
        keep = (~padding_mask).unsqueeze(-1).to(h.dtype)
        pooled = (h * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        return self.regressor(pooled)

In [None]:
def train_safe(model, dataloader, device, max_grad_norm=1.0, epochs=100, lr=1e-5,
               weight_decay=0.01, checkpoint_fn=None):
    """Train ``model`` with AdamW + Huber loss, skipping bad batches instead of crashing.

    Parameters
    ----------
    model : torch.nn.Module
        Maps an input batch to predictions of shape ``(batch, 1)``; the caller is
        expected to have moved it to ``device`` already.
    dataloader : iterable
        Yields ``(inputs, targets)`` pairs, or ``None`` for batches with no usable samples.
    device : torch.device
        Device each batch tensor is moved to.
    max_grad_norm : float
        Gradient-norm clipping threshold.
    epochs, lr, weight_decay
        Epoch count and AdamW hyper-parameters (defaults equal the former hard-coded values).
    checkpoint_fn : callable, optional
        Called as ``checkpoint_fn(model, epoch, avg_loss)`` after each finite epoch.
        Defaults to ``save_checkpoint`` (defined in a later cell; looked up at call time).

    Returns
    -------
    list of float
        Mean loss over the successfully processed batches of each completed epoch.
    """
    if checkpoint_fn is None:
        checkpoint_fn = save_checkpoint

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
    criterion = nn.HuberLoss()
    history = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0  # batches that actually contributed to total_loss
        for batch in dataloader:
            if batch is None:  # every sample in this batch failed to load
                continue

            # Kept outside the try: a structurally wrong batch is a pipeline bug and
            # should fail loudly rather than be skipped every epoch.
            inputs, targets = (t.to(device) for t in batch)

            try:
                optimizer.zero_grad()
                outputs = model(inputs)
                # squeeze only the feature axis so a batch of size 1 keeps shape (1,)
                loss = criterion(outputs.squeeze(-1), targets)

                if torch.isnan(loss):
                    print(f"Обнаружен NaN loss на эпохе {epoch}")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1

            except Exception as e:
                print(f"Ошибка обучения: {str(e)}")
                continue

        if n_batches == 0:
            print(f"Epoch {epoch+1}: no batches were processed, stopping training")
            break

        avg_loss = total_loss / n_batches
        if not np.isfinite(avg_loss):
            print("Обучение прервано из-за NaN")
            break

        scheduler.step(avg_loss)
        history.append(avg_loss)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, LR={optimizer.param_groups[0]['lr']:.2e}")
        checkpoint_fn(model, epoch, avg_loss)

    return history


In [None]:
def predict_safe(model, sample, device):
    """Predict a single value for one dataset sample, returning ``None`` on failure.

    ``sample[0]`` is the feature tensor of one sample; a batch axis is added before it
    is fed to ``model`` in eval mode without gradients. Returns the scalar prediction as
    a Python number, or ``None`` if the input contains NaN or any exception is raised
    (the reason is printed).
    """
    model.eval()
    with torch.no_grad():
        try:
            single_batch = sample[0].unsqueeze(0).to(device)
            if not torch.isnan(single_batch).any():
                return model(single_batch).item()
            print("Обнаружен NaN во входных данных")
            return None
        except Exception as err:
            print(f"Ошибка предсказания: {str(err)}")
            return None


In [None]:
def save_checkpoint(model, epoch, loss, save_dir="saves"):
    """Write the model's state dict plus epoch/loss metadata to a timestamped ``.pth`` file.

    The file is ``<save_dir>/checkpoint_epoch_<epoch+1>_loss_<loss:.4f>_<YYYYmmdd_HHMMSS>.pth``;
    ``save_dir`` is created if needed. The saved dict has keys ``epoch`` (0-based, as
    passed), ``model_state_dict``, ``loss`` and ``timestamp``. Prints the written path.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(save_dir, exist_ok=True)

    file_name = f"checkpoint_epoch_{epoch+1}_loss_{loss:.4f}_{stamp}.pth"
    destination = os.path.join(save_dir, file_name)

    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        loss=loss,
        timestamp=stamp,
    )
    torch.save(payload, destination)

    print(f"Checkpoint saved: {destination}")

In [None]:
    # Run configuration: tune these instead of editing the code below.
    SEED = 42
    DATA_PATH = "/content/my_data.csv"  # Colab upload location; adjust for other environments
    BATCH_SIZE = 16

    torch.manual_seed(SEED)  # reproducible weight init and shuffle order
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    df = load_and_validate_data(DATA_PATH)
    encoders, scaler = prepare_encoders(df)

    def collate_blends(samples):
        """Batch variable-length blends: drop failed samples, pad features, stack targets.

        Returns ``(inputs, targets)`` where ``inputs`` has shape
        ``(batch, max_components, n_features)`` padded with -1, so padded rows carry a
        negative component index; returns ``None`` if every sample in the batch failed.
        """
        samples = [s for s in samples if s is not None]
        if not samples:
            return None
        features, targets = zip(*samples)
        inputs = pad_sequence(list(features), batch_first=True, padding_value=-1.0)
        return inputs, torch.stack(targets)

    dataset = SafeBlendDataset(df, encoders, scaler)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_blends,
    )

    # The scaler was fitted on exactly the numeric feature columns.
    model = SafeTransformer(encoders, scaler.n_features_in_).to(device)

    train_safe(model, dataloader, device)