# 1. Data Augmentation

# 1.1 CGAN

In [None]:
import os
from collections import Counter
from decimal import Decimal, InvalidOperation

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers

# Global matplotlib look for every figure in this cell: Times New Roman,
# bold text, and explicit sizes for titles, axis labels, ticks and legends.
plt.rcParams.update({
    'font.sans-serif': ['Times New Roman'],
    'axes.unicode_minus': False,
    'font.size': 12,
    'font.weight': 'bold',
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'legend.title_fontsize': 12,
})

# Change the data path when testing.
df = pd.read_csv('E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/data/Well_B_MLR.csv')

if 'LostCirculation' not in df.columns:
    raise ValueError("数据集必须包含 'LostCirculation' 列。")

# Split features from the binary target and remember the per-feature range;
# the range is used later to map generator output back to the original scale.
X = df.drop('LostCirculation', axis=1)
y = df['LostCirculation']
X_min = X.min().values
X_max = X.max().values


def _count_decimal_places(value):
    """Return the number of digits after the decimal point in ``str(value)``.

    Parsing through ``Decimal`` also handles scientific notation such as
    ``1.5e-05`` (6 places). Values that are not finite numbers count as 0.
    """
    try:
        exponent = Decimal(str(value)).as_tuple().exponent
    except InvalidOperation:
        return 0
    # For NaN / infinity the exponent is a letter code, not an int.
    if not isinstance(exponent, int):
        return 0
    return max(-exponent, 0)


# Precision to round each generated feature to. The maximum over the whole
# column is used: looking at a single row under-estimates the precision
# whenever that row happens to hold a "short" value (e.g. 2.5 next to 2.125).
decimal_places = [
    max((_count_decimal_places(value) for value in X[col].dropna().unique()), default=0)
    for col in X.columns
]

def build_generator(input_dim, output_dim, num_classes):
    """Build the CGAN generator network.

    Args:
        input_dim: Width of the noise part of the input.
        output_dim: Number of values produced per sample.
        num_classes: Width of the class-condition part appended to the noise.

    Returns:
        A Keras ``Sequential`` model taking ``input_dim + num_classes`` inputs
        and producing ``output_dim`` sigmoid outputs, i.e. values in [0, 1].
    """
    conditioned_width = input_dim + num_classes
    stack = [
        layers.Dense(512, activation='relu', input_dim=conditioned_width),
        layers.Dense(256, activation='relu'),
        # Sigmoid keeps every generated feature inside the unit interval.
        layers.Dense(output_dim, activation='sigmoid'),
    ]
    return tf.keras.Sequential(stack)

def build_discriminator(input_dim, num_classes):
    """Build the CGAN discriminator network.

    Args:
        input_dim: Number of feature values per sample.
        num_classes: Width of the class-condition part appended to the sample.

    Returns:
        A Keras ``Sequential`` model taking ``input_dim + num_classes`` inputs
        and producing a single sigmoid output per sample.
    """
    conditioned_width = input_dim + num_classes
    stack = [
        layers.Dense(512, activation='leaky_relu', input_dim=conditioned_width),
        layers.Dense(256, activation='leaky_relu'),
        # Single probability-like score per (sample, condition) pair.
        layers.Dense(1, activation='sigmoid'),
    ]
    return tf.keras.Sequential(stack)

# Fix the random seeds before any weights are created so that weight
# initialisation, noise vectors and mini-batch sampling are repeatable from
# run to run (the SMOTE experiment pins random_state=42 in the same way).
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

# Hyper-parameters.
z_dim = 100             # width of the generator's noise input
X_dim = X.shape[1]      # number of feature columns
num_classes = 2         # width of the class-condition vector
epochs = 1000           # maximum number of training iterations
batch_size = 64
learning_rate = 0.0001  # shared by both Adam optimizers
patience = 25           # early-stopping threshold for the training loop

generator = build_generator(z_dim, X_dim, num_classes)
discriminator = build_discriminator(X_dim, num_classes)

# Separate optimizers so each network keeps its own Adam state.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate)
loss_fn = tf.keras.losses.BinaryCrossentropy()

# Training bookkeeping: loss history for plotting, best values and the
# counter used for early stopping.
D_losses = []
G_losses = []
best_G_loss = float('inf')
best_D_loss = float('inf')
stopping_counter = 0

# ---------------------------------------------------------------------------
# CGAN training on the minority class (LostCirculation == 1).
# ---------------------------------------------------------------------------

# Real samples are min-max scaled to [0, 1] with X_min / X_max. The generator
# ends in a sigmoid and its output is later mapped back with the same
# X_min / X_max, so the discriminator must see real data on that same scale;
# otherwise it can tell real from fake purely by magnitude.
# Constant columns get a range of 1 to avoid dividing by zero.
feature_range = np.where(X_max > X_min, X_max - X_min, 1.0)
minority_scaled = ((X[y == 1].values - X_min) / feature_range).astype(np.float32)
if minority_scaled.shape[0] == 0:
    raise ValueError("No samples with LostCirculation == 1; the CGAN has nothing to learn from.")

# One-hot condition for class 1. The SAME condition is attached to real and
# generated samples: feeding different condition vectors for real and fake
# inputs lets the discriminator separate them from the condition alone and
# gives the generator no useful gradient.
class_condition = np.zeros((batch_size, num_classes), dtype=np.float32)
class_condition[:, 1] = 1.0

real_labels = np.ones((batch_size, 1), dtype=np.float32)
fake_labels = np.zeros((batch_size, 1), dtype=np.float32)

for epoch in range(epochs):
    # ---- Discriminator step ------------------------------------------------
    idx = np.random.randint(0, minority_scaled.shape[0], batch_size)
    real_samples = tf.constant(minority_scaled[idx])

    z = np.random.randn(batch_size, z_dim).astype(np.float32)
    noise_with_labels = np.concatenate((z, class_condition), axis=1)
    # Generated outside the tape: this step must not update the generator.
    fake_samples = generator(noise_with_labels)

    with tf.GradientTape() as tape:
        # Small Gaussian jitter on the real samples.
        noise = tf.random.normal(shape=tf.shape(real_samples), mean=0.0, stddev=0.01)
        real_input = tf.concat([real_samples + noise, class_condition], axis=1)
        fake_input = tf.concat([fake_samples, class_condition], axis=1)
        D_real_loss = loss_fn(real_labels, discriminator(real_input))
        D_fake_loss = loss_fn(fake_labels, discriminator(fake_input))
        D_loss = D_real_loss + D_fake_loss
    D_gradients = tape.gradient(D_loss, discriminator.trainable_weights)
    discriminator_optimizer.apply_gradients(zip(D_gradients, discriminator.trainable_weights))

    # ---- Generator step ----------------------------------------------------
    z = np.random.randn(batch_size, z_dim).astype(np.float32)
    noise_with_labels = np.concatenate((z, class_condition), axis=1)
    with tf.GradientTape() as tape:
        fake_samples = generator(noise_with_labels)
        fake_input = tf.concat([fake_samples, class_condition], axis=1)
        # The generator is rewarded when the discriminator answers "real".
        G_loss = loss_fn(real_labels, discriminator(fake_input))
    G_gradients = tape.gradient(G_loss, generator.trainable_weights)
    generator_optimizer.apply_gradients(zip(G_gradients, generator.trainable_weights))

    # Plain Python floats keep the history, the comparisons and the printed
    # messages free of tensor objects.
    D_loss_value = float(D_loss)
    G_loss_value = float(G_loss)
    D_losses.append(D_loss_value)
    G_losses.append(G_loss_value)

    # ---- Early stopping ----------------------------------------------------
    # The counter advances by exactly one per epoch in which neither loss set
    # a new best, so `patience` is measured in epochs.
    improved = False
    if G_loss_value < best_G_loss:
        best_G_loss = G_loss_value
        improved = True
    if D_loss_value < best_D_loss:
        best_D_loss = D_loss_value
        improved = True
    stopping_counter = 0 if improved else stopping_counter + 1

    if stopping_counter >= patience:
        print(f"训练提前结束，达到最佳生成器损失: {best_G_loss} 和判别器损失: {best_D_loss}。")
        break

    # Progress is reported every epoch.
    print(f"Epoch {epoch}, D_loss: {D_loss_value}, G_loss: {G_loss_value}")

# Plot the discriminator and generator loss histories and save the figure.
# NOTE(review): the file name is spelled 'cagn' (not 'cgan'); it is kept
# unchanged so existing references to the image keep working.
loss_plot_path = 'E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/picture/data_aug/cagn_loss.png'

# savefig does not create missing folders; make sure the target exists so a
# finished training run is not lost to a FileNotFoundError at this point.
os.makedirs(os.path.dirname(loss_plot_path), exist_ok=True)

plt.figure(figsize=(10, 5))
plt.plot(D_losses, label='Discriminator loss')
plt.plot(G_losses, label='Generator loss')
plt.title('Changes of CGAN train loss')
plt.xlabel('Train iteration')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.savefig(loss_plot_path, dpi=300)
plt.show()
        
# Generate just enough synthetic class-1 rows to match the class-0 count.
class_counts = Counter(y)
desired_minority_count = class_counts[0]
# Clamp at zero: a negative size makes np.random.randn raise when class 1 is
# already at least as frequent as class 0.
num_samples_to_generate = max(desired_minority_count - class_counts[1], 0)

if num_samples_to_generate > 0:
    z = np.random.randn(num_samples_to_generate, z_dim)
    labels = np.zeros((num_samples_to_generate, num_classes))
    labels[:, 1] = 1  # one-hot condition for class 1
    noise_with_labels = np.concatenate((z, labels), axis=1)
    generated_samples = generator(noise_with_labels).numpy()
else:
    # Nothing to generate; keep an empty array with the right column count.
    generated_samples = np.empty((0, X.shape[1]))

# Map the sigmoid output from [0, 1] back to each feature's original range.
generated_samples = generated_samples * (X_max - X_min) + X_min

# Round every feature to the precision recorded in decimal_places.
formatted_samples = np.array([np.round(col, decimals=dec) for col, dec in zip(generated_samples.T, decimal_places)]).T

# Columns holding integer codes must end up as whole numbers. Round to the
# nearest integer instead of truncating, which would bias every generated
# code downwards (2.9 would become 2).
integer_columns = ['WellName', 'WellType', 'Layer', 'Lithology', 'Formation']
for i, col in enumerate(X.columns):
    if col in integer_columns:
        formatted_samples[:, i] = np.rint(generated_samples[:, i])

# Append the synthetic rows; their labels reuse the dtype of y so the label
# column is not silently converted to floats.
X_augmented = np.vstack([X, formatted_samples])
y_augmented = np.hstack([y, np.ones(formatted_samples.shape[0], dtype=y.dtype)])

# Grouped bar chart of the class counts before and after CGAN augmentation.
data_before = pd.DataFrame({'Category': y}).assign(Type='Before augmentation')
data_after = pd.DataFrame({'Category': y_augmented}).assign(Type='After augmentation')
data = pd.concat([data_before, data_after])

fig, ax = plt.subplots(figsize=(8, 5))
sns.countplot(x='Category', hue='Type', data=data, palette=['#3498db', '#2ecc71'], ax=ax)
ax.set_xticks([0, 1])
ax.set_xticklabels(['0', '1'])

# Write the count above every bar that has a positive height.
for bar in ax.patches:
    bar_height = int(bar.get_height())
    if bar_height <= 0:
        continue
    ax.annotate(str(bar_height),
                xy=(bar.get_x() + bar.get_width() / 2., bar_height),
                ha='center', va='bottom', color='black')

ax.set(title="Category distribution before and after data augmentation",
       xlabel="Category",
       ylabel="Number of Samples")
ax.legend(loc='lower right')
fig.tight_layout()
ax.grid(False)
fig.savefig('E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/picture/data_aug/category_distribution.png', dpi=300)
plt.show()

# Rebuild a labelled DataFrame from the augmented arrays and write it to disk.
df_augmented = pd.DataFrame(X_augmented, columns=X.columns).assign(LostCirculation=y_augmented)

output_path = 'E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/data/Well_B_cgan_best.csv'
# utf-8-sig writes a BOM at the start of the file.
df_augmented.to_csv(output_path, encoding='utf-8-sig', index=False)
print(f"增强后的数据已保存至: {output_path}")

def plot_all_features_comparison_save(X, generated_samples, save_dir='E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/picture/data_aug/'):
    """Save one overlaid histogram per column comparing two DataFrames.

    Args:
        X: DataFrame drawn in green and labelled 'Raw data'.
        generated_samples: DataFrame with the same column names, drawn in
            pink and labelled 'Augmented data'.
        save_dir: Folder receiving one '<column>_comparison.png' per column;
            created if it does not exist.

    Raises:
        ValueError: If either input is not a pandas DataFrame.
        KeyError: If the two DataFrames do not share the same set of columns.
    """
    import os

    # Validate first so that a bad call has no file-system side effects.
    if not isinstance(X, pd.DataFrame) or not isinstance(generated_samples, pd.DataFrame):
        raise ValueError("输入的数据集必须都是 pandas DataFrame 类型")
    if set(X.columns) != set(generated_samples.columns):
        raise KeyError("两个数据集的特征列名不一致，无法进行对比。")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    # Note: this switches the global matplotlib style for subsequent figures.
    chosen_style = 'ggplot' if 'ggplot' in plt.style.available else 'default'
    plt.style.use(chosen_style)

    for feature in X.columns:
        fig = plt.figure(figsize=(8, 6))
        try:
            sns.histplot(X[feature], color='green', label='Raw data', kde=True, alpha=0.5)
            sns.histplot(generated_samples[feature], color='pink', label='Augmented data', kde=True, alpha=0.5)
            plt.title(f'Comparison via histogram - {feature}')
            plt.xlabel(feature, color='black')
            plt.ylabel('Frequency', color='black')
            plt.legend()
            plt.tight_layout()
            save_path = os.path.join(save_dir, f'{feature}_comparison.png')
            plt.savefig(save_path, dpi=300)
            plt.show()
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
    print(f"所有特征对比图已保存至文件夹: {save_dir}")

# One raw-vs-augmented histogram per column of the original and CGAN-augmented data.
plot_all_features_comparison_save(df, df_augmented)

# Persist both trained networks. The target folder is created first because
# saving into a missing folder would fail after all the training work is done.
for model, model_path in (
    (generator, 'E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/model/cgan_generator.h5'),
    (discriminator, 'E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/model/cgan_discriminator.h5'),
):
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    model.save(model_path)
print("生成器和判别器模型已保存。")

# 1.2 SMOTE

In [None]:
import pandas as pd
from imblearn.over_sampling import SMOTE
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
# Global matplotlib look for every figure in this cell: Times New Roman,
# bold text, and explicit sizes for titles, axis labels, ticks and legends.
plt.rcParams.update({
    'font.sans-serif': ['Times New Roman'],
    'axes.unicode_minus': False,
    'font.size': 12,
    'font.weight': 'bold',
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'legend.title_fontsize': 12,
})

# Load the raw records; the binary target column is mandatory.
df = pd.read_csv('E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/data/Well_B_MLR.csv')

if 'LostCirculation' not in df.columns:
    raise ValueError("数据集必须包含 'LostCirculation' 列。")
X = df.drop(columns='LostCirculation')
y = df.loc[:, 'LostCirculation']

# Oversample with SMOTE using a fixed random_state for repeatability.
# NOTE(review): every feature column is passed to SMOTE as-is, including any
# integer-coded ones — confirm that interpolated codes are acceptable.
smote = SMOTE(random_state=42, sampling_strategy='auto')
X_resampled, y_resampled = smote.fit_resample(X, y)

# Grouped bar chart of the class counts before and after SMOTE augmentation.
data_before = pd.DataFrame({'Category': y}).assign(Type='Before augmentation')
data_after = pd.DataFrame({'Category': y_resampled}).assign(Type='After augmentation')
data = pd.concat([data_before, data_after])

fig, ax = plt.subplots(figsize=(8, 5))
sns.countplot(x='Category', hue='Type', data=data, palette=['#3498db', '#2ecc71'], ax=ax)
ax.set_xticks([0, 1])
ax.set_xticklabels(['0', '1'])

# Write the count above every bar that has a positive height.
for bar in ax.patches:
    bar_height = int(bar.get_height())
    if bar_height <= 0:
        continue
    ax.annotate(str(bar_height),
                xy=(bar.get_x() + bar.get_width() / 2., bar_height),
                ha='center', va='bottom', color='black')

ax.set(title="Category distribution before and after SMOTE augmentation",
       xlabel="Category",
       ylabel="Number of Samples")
ax.legend(loc='lower right')

fig.tight_layout()
ax.grid(False)
fig.savefig('E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/picture/data_aug/category_distribution_smote.png', dpi=300)
plt.show()

# Combine the resampled features and labels into one DataFrame and write it to disk.
df_augmented = pd.DataFrame(X_resampled, columns=X.columns).assign(LostCirculation=y_resampled)

output_path = 'E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/data/Well_B_smote.csv'
# utf-8-sig writes a BOM at the start of the file.
df_augmented.to_csv(output_path, encoding='utf-8-sig', index=False)
print(f"增强后的数据已保存至: {output_path}")

def plot_all_features_comparison_save(X, generated_samples, save_dir='E:/jupyter/lost_circulation/records/paper-bhyt/Diagnosis/picture/data_aug/'):
    """Save one overlaid histogram per column comparing two DataFrames.

    Args:
        X: DataFrame drawn in green and labelled 'Raw data'.
        generated_samples: DataFrame with the same column names, drawn in
            pink and labelled 'Augmented data'.
        save_dir: Folder receiving one '<column>_comparison_smote.png' per
            column; created if it does not exist.

    Raises:
        ValueError: If either input is not a pandas DataFrame.
        KeyError: If the two DataFrames do not share the same set of columns.
    """
    import os

    # Validate first so that a bad call has no file-system side effects.
    if not isinstance(X, pd.DataFrame) or not isinstance(generated_samples, pd.DataFrame):
        raise ValueError("输入的数据集必须都是 pandas DataFrame 类型")

    if set(X.columns) != set(generated_samples.columns):
        raise KeyError("两个数据集的特征列名不一致，无法进行对比。")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    # Note: this switches the global matplotlib style for subsequent figures.
    chosen_style = 'ggplot' if 'ggplot' in plt.style.available else 'default'
    plt.style.use(chosen_style)

    for feature in X.columns:
        fig = plt.figure(figsize=(8, 6))
        try:
            sns.histplot(X[feature], color='green', label='Raw data', kde=True, alpha=0.5)
            sns.histplot(generated_samples[feature], color='pink', label='Augmented data', kde=True, alpha=0.5)
            plt.title(f'Comparison via histogram - {feature}')
            plt.xlabel(feature, color='black')
            plt.ylabel('Frequency', color='black')
            plt.legend()
            plt.tight_layout()
            save_path = os.path.join(save_dir, f'{feature}_comparison_smote.png')
            plt.savefig(save_path, dpi=300)
            plt.show()
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
    print(f"所有特征对比图已保存至文件夹: {save_dir}")

# One raw-vs-augmented histogram per column of the original and SMOTE-resampled data.
plot_all_features_comparison_save(df, df_augmented)
