**Analyzing health states durations**


In [None]:
import pandas as pd

# Read the observations, strip stray whitespace from the header, parse the
# 'date' column as datetimes and order the rows by cow, then by date.
df = (
    pd.read_csv(r"C:\Users\lamia\Downloads\filtered_dataset1_more_than_18_obs.csv")
    .rename(columns=str.strip)
    .assign(date=lambda frame: pd.to_datetime(frame['date']))
    .sort_values(by=['cow', 'date'])
)

# Function to get durations for a specific health state
def get_health_state_durations(df, health_state_column):
    """Find, per cow, the uninterrupted periods during which a health state is active.

    Rows are scanned per cow in date order. A period is a run of rows whose
    ``health_state_column`` equals 1; it ends when a row with another value is
    met, when the next active row is more than one day after the previous
    row, or when the cow's rows are exhausted.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'cow', 'date' (datetime-like) and ``health_state_column``.
    health_state_column : str
        Name of the 0/1 indicator column to analyse.

    Returns
    -------
    pandas.DataFrame
        One row per period with columns 'cow', 'health_state', 'start_date',
        'end_date' and 'duration_days' (inclusive day count). The columns are
        present even when no period is found, so callers can always select
        them.
    """
    result_columns = ['cow', 'health_state', 'start_date', 'end_date', 'duration_days']
    durations = []

    def _record_period(cow_id, start_date, end_date):
        # Single place where a finished period is turned into an output row.
        durations.append({
            'cow': cow_id,
            'health_state': health_state_column,
            'start_date': start_date,
            'end_date': end_date,
            'duration_days': (end_date - start_date).days + 1
        })

    for cow_id, group in df.groupby('cow'):
        group = group.sort_values('date')

        start_date = None  # start of the open period; None while no period is open
        prev_date = None   # date of the previously visited row

        # Only two columns are needed, so iterate over them directly instead of
        # materialising a Series per row with iterrows().
        for date, flag in zip(group['date'], group[health_state_column]):
            if flag == 1:
                if start_date is None:
                    start_date = date
                elif (date - prev_date).days > 1:
                    # Gap in the observations: close the period at the last
                    # observed day and open a new one here.
                    _record_period(cow_id, start_date, prev_date)
                    start_date = date
            elif start_date is not None:
                # State switched off: the period ended on the previous row.
                _record_period(cow_id, start_date, prev_date)
                start_date = None

            prev_date = date

        # The cow's last rows were still active.
        if start_date is not None:
            _record_period(cow_id, start_date, prev_date)

    return pd.DataFrame(durations, columns=result_columns)

# Indicator columns whose period lengths are summarised below
health_states = [
    'oestrus', 'calving', 'lameness', 'mastitis', 'LPS',
    'acidosis', 'other_disease', 'accidents', 'disturbance',
    'mixing', 'management_changes'
]

# Table of periods for every health state
state_durations = {state: get_health_state_durations(df, state) for state in health_states}

# Shortest and longest period per state (None when the state never occurs)
durations_min_max = {}
for state, durations_df in state_durations.items():
    if durations_df.empty:
        durations_min_max[state] = {'min_days': None, 'max_days': None}
        continue
    lengths = durations_df['duration_days']
    durations_min_max[state] = {'min_days': lengths.min(), 'max_days': lengths.max()}

# Print one sentence per state
for state, durations in durations_min_max.items():
    min_days, max_days = durations['min_days'], durations['max_days']

    if min_days is None or max_days is None:
        print(f"No data available for \"{state}\".")
        continue

    if min_days != max_days:
        print(f"{state.capitalize()} lasts between {min_days} and {max_days} days.")
        continue

    print(f"{state.capitalize()} lasts {min_days} day{'s' if min_days > 1 else ''}.")

Oestrus lasts 1 day.
Calving lasts 1 day.
Lameness lasts 1 day.
Mastitis lasts between 2 and 3 days.
Lps lasts 1 day.
No data available for "acidosis".
Other_disease lasts 1 day.
No data available for "accidents".
Disturbance lasts between 1 and 2 days.
Mixing lasts between 1 and 2 days.
No data available for "management_changes".


**Étalement des jours** (spreading each event's label over the neighbouring days)

In [None]:
# import pandas as pd
# from datetime import timedelta

# # Reload the dataset
# file_path = 'dataset3-1 (1).csv'
# df = pd.read_csv(file_path)

# # Step 1: Aggregate per cow-date (no hour) to detect events
# agg_cols = ['oestrus', 'calving', 'lameness', 'mastitis', 'other_disease', 'accidents', 'disturbance', 'mixing', 'management_changes']
# daily = df.groupby(['cow', 'date'])[agg_cols].max().reset_index()

# # Convert date to datetime
# daily['date'] = pd.to_datetime(daily['date'])

# # Step 2: Build a dataframe with all possible cow x day
# all_dates = pd.date_range(daily['date'].min() - timedelta(days=7), daily['date'].max() + timedelta(days=7))
# cows = daily['cow'].unique()

# # Create full cow-date combination
# full_daily = pd.MultiIndex.from_product([cows, all_dates], names=['cow', 'date']).to_frame(index=False)

# # Merge recorded events
# full_daily = full_daily.merge(daily, on=['cow', 'date'], how='left')

# # Fill missing values for event columns with 0
# for col in agg_cols:
#     if col not in full_daily:
#         full_daily[col] = 0
#     else:
#         full_daily[col] = full_daily[col].fillna(0)

# # Add default label
# full_daily['LABEL'] = 'control'


# # Step 3: Detect episodes and spread labels
# conditions = ['oestrus', 'calving', 'lameness', 'mastitis', 'other_disease', 'accidents', 'disturbance', 'mixing', 'management_changes']

# # Prepare a new column to receive the aligned label
# full_daily['LABEL'] = 'control'

# # Define how many days before and after depending on the condition
# spread_rules = {
#     'oestrus': {'before': 1, 'after': 1},
#     'calving': {'before': 2, 'after': 1},
#     'lameness': {'before': 2, 'after': 1},
#     'mastitis': {'before': 2, 'after': 1},
#     'other_disease': {'before': 2, 'after': 1},
#     'accidents': {'before': 2, 'after': 1},
#     'disturbance': {'before': 0, 'after': 0},
#     'mixing': {'before': 0, 'after': 0},
#     'management_changes': {'before': 0, 'after': 0},
# }

# # for cond in conditions:
# #     sub = full_daily[full_daily[cond] == 1][['cow', 'date']].sort_values(['cow', 'date'])
# #     for cow_id in sub['cow'].unique():
# #         cow_days = sub[sub['cow'] == cow_id]['date'].sort_values()
# #         # Detect episodes
# #         episode = []
# #         prev_day = None
# #         for day in cow_days:
# #             if prev_day is None or (day - prev_day).days > 1:
# #                 # New episode starts
# #                 if episode:
# #                     # Process previous episode
# #                     min_day = min(episode)
# #                     max_day = max(episode)
# #                     spread = spread_rules[cond]
# #                     spread_days = pd.date_range(min_day - timedelta(days=spread['before']), max_day + timedelta(days=spread['after']))
# #                     mask = (full_daily['cow'] == cow_id) & (full_daily['date'].isin(spread_days))
# #                     full_daily.loc[mask & (full_daily['LABEL'] == 'control'), 'LABEL'] = cond
# #                     full_daily.loc[mask, cond] = 1  # Force the corresponding condition column to 1

# #                 episode = [day]
# #             else:
# #                 episode.append(day)
# #             prev_day = day
# #         # Process the last episode
# #         if episode:
# #             min_day = min(episode)
# #             max_day = max(episode)
# #             spread = spread_rules[cond]
# #             spread_days = pd.date_range(min_day - timedelta(days=spread['before']), max_day + timedelta(days=spread['after']))
# #             mask = (full_daily['cow'] == cow_id) & (full_daily['date'].isin(spread_days))
# #             full_daily.loc[mask & (full_daily['LABEL'] == 'control'), 'LABEL'] = cond
# for cond in conditions:
#     sub = full_daily[full_daily[cond] == 1][['cow', 'date']].sort_values(['cow', 'date'])
#     for cow_id in sub['cow'].unique():
#         cow_days = sub[sub['cow'] == cow_id]['date'].sort_values()
#         episode = []
#         prev_day = None
#         for day in cow_days:
#             if prev_day is None or (day - prev_day).days > 1:
#                 if episode:
#                     min_day = min(episode)
#                     max_day = max(episode)
#                     spread = spread_rules[cond]
#                     spread_days = pd.date_range(min_day - timedelta(days=spread['before']), max_day + timedelta(days=spread['after']))
#                     mask = (full_daily['cow'] == cow_id) & (full_daily['date'].isin(spread_days))
#                     full_daily.loc[mask & (full_daily['LABEL'] == 'control'), 'LABEL'] = cond
#                     full_daily.loc[mask, cond] = 1  # <--- THIS IS THE NEW LINE
#                 episode = [day]
#             else:
#                 episode.append(day)
#             prev_day = day
#         if episode:
#             min_day = min(episode)
#             max_day = max(episode)
#             spread = spread_rules[cond]
#             spread_days = pd.date_range(min_day - timedelta(days=spread['before']), max_day + timedelta(days=spread['after']))
#             mask = (full_daily['cow'] == cow_id) & (full_daily['date'].isin(spread_days))
#             full_daily.loc[mask & (full_daily['LABEL'] == 'control'), 'LABEL'] = cond
#             full_daily.loc[mask, cond] = 1  # <--- AND THIS


# # Step 4: Map back to hourly level
# # Merge back to original hourly dataframe
# final = df.copy()
# final['date'] = pd.to_datetime(final['date'])
# final = final.merge(full_daily[['cow', 'date', 'LABEL']], on=['cow', 'date'], how='left')

# # Display a sample
# final[['cow', 'date', 'hour', 'LABEL']].sample(100)
# # Save the final aligned dataset to a CSV
# final.to_csv(r'c:/users/lamia/Downloads/labelled&aligned_dataset.csv', index=False)


In [None]:
import pandas as pd
from datetime import timedelta

# Step 0: Read the hourly dataset and normalise the header (trimmed, lower case)
file_path = r"C:\Users\lamia\Downloads\filtered_dataset1_more_than_18_obs.csv"
df = pd.read_csv(file_path)
df.columns = [name.strip().lower() for name in df.columns]

# Step 1: Every column that is not an identifier/behaviour column is an event
non_event_cols = ['cow', 'date', 'hour', 'in_alleys', 'rest', 'eat', 'activity_level', 'ok']
event_cols = []
for col in df.columns:
    if col not in non_event_cols:
        event_cols.append(col)

# Step 2: One row per cow and day, holding the daily maximum of each event column
daily_max = df.groupby(['cow', 'date'])[event_cols].max()
daily = daily_max.reset_index()
daily['date'] = pd.to_datetime(daily['date'])

# Step 3: Every cow crossed with every calendar day, padded by 7 days on both sides
padding = timedelta(days=7)
first_day = daily['date'].min() - padding
last_day = daily['date'].max() + padding
all_dates = pd.date_range(first_day, last_day)
cows = daily['cow'].unique()
grid = pd.MultiIndex.from_product([cows, all_dates], names=['cow', 'date'])

# Attach the recorded events; cow-days without a record get 0 for every event
full_daily = grid.to_frame(index=False).merge(daily, on=['cow', 'date'], how='left')
full_daily[event_cols] = full_daily[event_cols].fillna(0)

# Defaults before spreading: every day is 'control' and ok
full_daily = full_daily.assign(LABEL='control', ok=1)

# Step 4: Spread rules
# Number of days to add before the first day and after the last day of an
# episode, per event column. Event columns without a rule are left untouched.
spread_rules = {
    'oestrus': {'before': 1, 'after': 1},
    'calving': {'before': 2, 'after': 1},
    'lameness': {'before': 2, 'after': 1},
    'mastitis': {'before': 2, 'after': 1},
    'lps': {'before': 2, 'after': 1},
    'acidosis': {'before': 2, 'after': 1},
    'other_disease': {'before': 2, 'after': 1},
    'accidents': {'before': 2, 'after': 1},
    'disturbance': {'before': 0, 'after': 0},
    'mixing': {'before': 0, 'after': 0},
    'management_changes': {'before': 0, 'after': 0},
}


def _iter_episodes(days):
    """Yield (first_day, last_day) for each run of chronologically sorted days.

    A new run starts whenever a day lies more than one day after the previous one.
    """
    first_day = last_day = None
    for day in days:
        if first_day is None:
            first_day = last_day = day
        elif (day - last_day).days > 1:
            yield first_day, last_day
            first_day = last_day = day
        else:
            last_day = day
    if first_day is not None:
        yield first_day, last_day


def _spread_episode(table, cow_id, cond, first_day, last_day, rule):
    """Mark one episode's padded window in ``table`` (modified in place).

    Inside the window the event column is forced to 1, 'ok' is cleared, and
    LABEL is set to the event only where it is still 'control', so a label
    written by an earlier event is never overwritten.
    """
    window = pd.date_range(
        first_day - timedelta(days=rule['before']),
        last_day + timedelta(days=rule['after']),
    )
    mask = (table['cow'] == cow_id) & table['date'].isin(window)
    table.loc[mask, cond] = 1
    table.loc[mask & (table['LABEL'] == 'control'), 'LABEL'] = cond
    table.loc[mask, 'ok'] = 0


for cond in event_cols:
    if cond not in spread_rules:
        continue
    rule = spread_rules[cond]
    # Snapshot of the recorded event days, taken before any spreading of this
    # event, so days added by the spreading never seed further episodes.
    sub = full_daily.loc[full_daily[cond] == 1, ['cow', 'date']].sort_values(['cow', 'date'])
    for cow_id, cow_days in sub.groupby('cow', sort=False)['date']:
        for first_day, last_day in _iter_episodes(cow_days):
            _spread_episode(full_daily, cow_id, cond, first_day, last_day, rule)

# Step 5: Replace the hourly event/ok columns by their day-level, spread versions
df['date'] = pd.to_datetime(df['date'])
superseded_cols = [*event_cols, 'ok']
df = df.drop(columns=superseded_cols, errors='ignore')  # avoid _x/_y duplicates in the merge

# Every hourly row receives the LABEL, event flags and ok flag of its cow-day
daily_labels = full_daily[['cow', 'date', 'LABEL', *superseded_cols]]
final = pd.merge(df, daily_labels, how='left', on=['cow', 'date'])

# Step 6: Where a row carries an event as LABEL but the event flag is 0, set the flag
for cond in event_cols:
    labelled_but_unset = final['LABEL'].eq(cond) & final[cond].eq(0)
    final.loc[labelled_but_unset, cond] = 1

# Save
final.to_csv('c:/users/lamia/Downloads/labelled&aligned_dataset1.csv', index=False)

# Show sample
#final[['cow', 'date', 'hour', 'LABEL', 'oestrus']].sample(10)


In [None]:
# Visual check: one label per day for a single cow

# Cow to inspect (fixed id; final['cow'].sample(1).iloc[0] would pick one at random)
cow_to_check = 6601

# Hourly rows of that cow only
is_selected_cow = final['cow'] == cow_to_check
cow_data = final.loc[is_selected_cow, ['cow', 'date', 'hour', 'LABEL']]

# Collapse the hours of each day into the first mode of LABEL
daily_view = (
    cow_data
    .groupby(['cow', 'date'])['LABEL']
    .agg(lambda labels: labels.mode()[0])
    .reset_index()
)

# Show the days in chronological order
print(f"Consecutive days labeling for Cow {cow_to_check}:")
display(daily_view.sort_values('date'))


Consecutive days labeling for Cow 6601:


Unnamed: 0,cow,date,LABEL
0,6601,2018-10-26,control
1,6601,2018-10-27,control
2,6601,2018-10-28,control
3,6601,2018-10-29,control
4,6601,2018-10-30,control
...,...,...,...
163,6601,2019-04-13,control
164,6601,2019-04-14,control
165,6601,2019-04-15,control
166,6601,2019-04-16,control


In [None]:
# Visual check: one label per day for a randomly drawn cow

# Draw the id from one randomly sampled row of the hourly table
cow_to_check = final['cow'].sample(1).iloc[0]

# Hourly rows of that cow only
is_selected_cow = final['cow'] == cow_to_check
cow_data = final.loc[is_selected_cow, ['cow', 'date', 'hour', 'LABEL']]

# Collapse the hours of each day into the first mode of LABEL
daily_view = (
    cow_data
    .groupby(['cow', 'date'])['LABEL']
    .agg(lambda labels: labels.mode()[0])
    .reset_index()
)

# Show the days in chronological order
print(f"Consecutive days labeling for Cow {cow_to_check}:")
display(daily_view.sort_values('date'))


Consecutive days labeling for Cow 6612:


Unnamed: 0,cow,date,LABEL
0,6612,2018-10-26,control
1,6612,2018-10-27,control
2,6612,2018-10-28,control
3,6612,2018-10-29,control
4,6612,2018-10-30,control
...,...,...,...
163,6612,2019-04-13,control
164,6612,2019-04-14,control
165,6612,2019-04-15,control
166,6612,2019-04-16,control


**Investigating before imputation**

In [None]:
import pandas as pd

# Step 0: Read the labelled dataset and normalise the header
file_path = r"C:\Users\lamia\Downloads\labelled&aligned_dataset1.csv"
df = pd.read_csv(file_path)
df.columns = [name.strip().lower() for name in df.columns]

# Step 1: Classes that are kept vs. classes that will be dropped
useful_classes = ['mastitis', 'lameness', 'oestrus', 'calving', 'other_disease', 'ok']
removed_classes = ['management_changes', 'mixing', 'disturbance', 'accidents', 'lps', 'acidosis']

# Steps 2-3: Row-wise maximum over each group of indicator columns
df = df.assign(
    had_removed_class=df[removed_classes].max(axis=1),
    has_useful_class=df[useful_classes].max(axis=1),
)

# Step 4: Rows that carry a removed class and no kept class
only_removed = df['had_removed_class'].eq(1) & df['has_useful_class'].eq(0)
rows_to_reassign = df.loc[only_removed]

# Step 5: Summary statistics
total_rows = df.shape[0]
total_to_reassign = rows_to_reassign.shape[0]
proportion = total_to_reassign / total_rows * 100

for line in (
    f"Total rows in dataset: {total_rows}",
    f"Rows needing reassignment: {total_to_reassign}",
    f"Proportion needing reassignment: {proportion:.2f}%",
):
    print(line)


Total rows in dataset: 106269
Rows needing reassignment: 7361
Proportion needing reassignment: 6.93%


**Imputation using K-NN**

In [None]:
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

# Step 0: Read the labelled dataset and normalise the header
file_path = r"C:\Users\lamia\Downloads\labelled&aligned_dataset1.csv"
df = pd.read_csv(file_path)
df.columns = [name.strip().lower() for name in df.columns]

# Step 1: Classes that are kept vs. classes that will be dropped
useful_classes = ['mastitis', 'lameness', 'oestrus', 'calving', 'other_disease', 'ok']
removed_classes = ['management_changes', 'mixing', 'disturbance', 'accidents', 'lps', 'acidosis']

# Step 2: Row-wise maximum over the removed-class columns
removed_activity = df[removed_classes].max(axis=1)
df = df.assign(needs_replacement=removed_activity)

# Step 3: Create the physiological label
# Priority: if multiple physiological labels active, pick the first in useful_classes
def get_physio_label(row):
    """Return the first name in ``useful_classes`` whose value in ``row`` is 1.

    The order of ``useful_classes`` decides which class wins when several are
    active; ``None`` is returned when none of them is.
    """
    active = (cond for cond in useful_classes if row[cond] == 1)
    return next(active, None)

df['physio_label'] = df.apply(get_physio_label, axis=1)

# Step 4: Split into rows to relabel vs. rows usable as they are
has_removed_class = df['needs_replacement'].eq(1)
unlabelled = df['physio_label'].isna()

# Removed class active and no kept class: label to be predicted
to_replace = df.loc[has_removed_class & unlabelled].copy()
# No removed class active and a kept class present
clean_physio = df.loc[df['needs_replacement'].eq(0) & ~unlabelled].copy()

# # Step 5: Train K-NN on behavior features
# behavior_features = ['in_alleys', 'rest', 'eat', 'activity_level']

# X_train = clean_physio[behavior_features]
# y_train = clean_physio['physio_label']

# # Standardize features
# scaler = StandardScaler()
# X_train_scaled = scaler.fit_transform(X_train)

# # Train KNN
# knn = KNeighborsClassifier(n_neighbors=5)
# knn.fit(X_train_scaled, y_train)

# # Step 6: Predict missing labels
# X_missing = scaler.transform(to_replace[behavior_features])
# predicted_labels = knn.predict(X_missing)

# # Assign the new labels
# to_replace['physio_label'] = predicted_labels
# Step 5: Train K-NN on behavior features, EXCLUDING OK
behavior_features = ['in_alleys', 'rest', 'eat', 'activity_level']

# Only disease cases for training: with 'ok' left out of the training set,
# every reassigned row receives one of these five labels.
disease_classes = ['mastitis', 'lameness', 'oestrus', 'calving', 'other_disease']

train_disease = clean_physio[clean_physio['physio_label'].isin(disease_classes)]

X_train = train_disease[behavior_features]
y_train = train_disease['physio_label']

# Standardize features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)

# Train KNN
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train_scaled, y_train)

# Step 6: Predict missing labels
# scikit-learn rejects an input with zero rows, so skip the prediction when
# there is nothing to reassign.
if to_replace.empty:
    predicted_labels = np.array([], dtype=object)
else:
    X_missing = scaler.transform(to_replace[behavior_features])
    predicted_labels = knn.predict(X_missing)

# Assign the new labels
to_replace['physio_label'] = predicted_labels

# Step 6.1: Analyze reassignment
reassignment_summary = pd.Series(predicted_labels).value_counts().reset_index()
reassignment_summary.columns = ['Physiological_Class', 'Number_of_Reassigned_Samples']

print(reassignment_summary)

# Step 7: Merge back
# For each predicted label, mark the corresponding event column to 1
for cond in useful_classes:
    to_replace.loc[to_replace['physio_label'] == cond, cond] = 1

# Keep every row that already has a physiological label, including rows where
# a removed class was active as well: those rows are in neither clean_physio
# nor to_replace, so concatenating only these two would silently drop them.
# Their removed-class columns are deleted just below like everyone else's.
already_labelled = df[df['physio_label'].notnull()]
df_final = pd.concat([already_labelled, to_replace], axis=0)

# Clean removed columns
df_final = df_final.drop(columns=removed_classes + ['needs_replacement'])

# Optional: reorder if needed
df_final = df_final.sort_values(by=['cow', 'date', 'hour']).reset_index(drop=True)

# Step 8: Save the cleaned and reconstructed dataset
df_final.to_csv('c:/users/lamia/Downloads/final2_cleaned_and_reassigned_dataset1.csv', index=False)

# Show sample
df_final[['cow', 'date', 'hour', 'physio_label']].sample(10)


  Physiological_Class  Number_of_Reassigned_Samples
0             oestrus                          5892
1             calving                           442
2       other_disease                           440
3            mastitis                           324
4            lameness                           263


Unnamed: 0,cow,date,hour,physio_label
45249,6646,2019-02-07,20,oestrus
86418,6699,2019-02-17,16,ok
35849,6638,2018-11-23,15,oestrus
90129,6701,2019-02-19,12,ok
28635,6634,2018-12-13,14,ok
4405,6610,2018-11-10,18,ok
16295,6621,2018-11-29,19,ok
9939,6612,2019-01-12,13,mastitis
30926,6634,2019-03-25,5,lameness
17102,6621,2019-01-02,10,ok


**Imputation using GAN**

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler

# ===== Step 0: Load and preprocess dataset =====
file_path = r"C:\Users\lamia\Downloads\labelled&aligned_dataset1.csv"
df = pd.read_csv(file_path)
df.columns = [name.strip().lower() for name in df.columns]

# Classes that are kept vs. classes that will be dropped
useful_classes = ['mastitis', 'lameness', 'oestrus', 'calving', 'other_disease', 'ok']
removed_classes = ['management_changes', 'mixing', 'disturbance', 'accidents', 'lps', 'acidosis']

# Row-wise maximum over the removed-class columns
removed_activity = df[removed_classes].max(axis=1)
df = df.assign(needs_replacement=removed_activity)

def get_physio_label(row):
    """Return the first name in ``useful_classes`` whose value in ``row`` is 1.

    The order of ``useful_classes`` decides which class wins when several are
    active; ``None`` is returned when none of them is.
    """
    active = (cond for cond in useful_classes if row[cond] == 1)
    return next(active, None)

df['physio_label'] = df.apply(get_physio_label, axis=1)

has_removed_class = df['needs_replacement'].eq(1)
unlabelled = df['physio_label'].isna()
# Removed class active and no kept class: rows that will be regenerated
to_replace = df.loc[has_removed_class & unlabelled].copy()
# No removed class active and a kept class present
clean_physio = df.loc[df['needs_replacement'].eq(0) & ~unlabelled].copy()

# ===== Step 1: Prepare training data =====
behavior_features = ['in_alleys', 'rest', 'eat', 'activity_level']
disease_classes = ['mastitis', 'lameness', 'oestrus', 'calving', 'other_disease']
is_disease = clean_physio['physio_label'].isin(disease_classes)
train_disease = clean_physio.loc[is_disease]

X_train, y_train = train_disease[behavior_features], train_disease['physio_label']

# Fit the standardisation on the disease rows, then apply it to the same rows
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)

# ===== Step 2: Define GAN model =====
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

latent_dim = 10
_, feature_dim = X_train_scaled.shape

class Generator(nn.Module):
    """Two-layer perceptron mapping a noise vector to a feature vector.

    Parameters
    ----------
    noise_dim : int, optional
        Size of the input noise vector. Defaults to the notebook-level
        ``latent_dim``.
    out_dim : int, optional
        Number of generated features. Defaults to the notebook-level
        ``feature_dim``.
    hidden_dim : int, default 16
        Width of the hidden layer.
    """
    def __init__(self, noise_dim=None, out_dim=None, hidden_dim=16):
        super().__init__()
        # Fall back to the notebook globals so that Generator() keeps working.
        if noise_dim is None:
            noise_dim = latent_dim
        if out_dim is None:
            out_dim = feature_dim
        self.model = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    def forward(self, z):
        """Return the generated features for the batch of noise vectors ``z``."""
        return self.model(z)

class Discriminator(nn.Module):
    """Two-layer perceptron scoring a feature vector with a value in [0, 1].

    Parameters
    ----------
    in_dim : int, optional
        Number of input features. Defaults to the notebook-level
        ``feature_dim``.
    hidden_dim : int, default 16
        Width of the hidden layer.
    """
    def __init__(self, in_dim=None, hidden_dim=16):
        super().__init__()
        # Fall back to the notebook global so that Discriminator() keeps working.
        if in_dim is None:
            in_dim = feature_dim
        self.model = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        """Return one sigmoid score per row of ``x``."""
        return self.model(x)

G = Generator().to(device)
D = Discriminator().to(device)

criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)

# ===== Step 3: Train GAN =====
real_data = torch.tensor(X_train_scaled, dtype=torch.float32).to(device)
batch_size = 64  # defined but not referenced below: each step uses all of real_data
epochs = 1000

# Targets are the same at every step, so build them once
n_real = real_data.size(0)
real_labels = torch.ones(n_real, 1).to(device)
fake_labels = torch.zeros(n_real, 1).to(device)


def sample_noise():
    """Draw one noise vector per real training row."""
    return torch.randn(n_real, latent_dim).to(device)


for epoch in range(epochs):
    # Discriminator step: real rows should score 1, generated rows 0.
    # The generated batch is detached so this step leaves G untouched.
    fake_data = G(sample_noise())
    d_loss_real = criterion(D(real_data), real_labels)
    d_loss_fake = criterion(D(fake_data.detach()), fake_labels)
    d_loss = d_loss_real + d_loss_fake

    D.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Generator step: a fresh generated batch should be scored as real
    fake_data = G(sample_noise())
    g_loss = criterion(D(fake_data), real_labels)

    G.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    if epoch % 200 == 0:
        print(f"Epoch {epoch}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

# ===== Step 4: Generate synthetic features for missing samples =====
num_missing = len(to_replace)
with torch.no_grad():
    z = torch.randn(num_missing, latent_dim).to(device)
    generated_features = G(z).cpu().numpy()
# Optional: clip to [-3, 3] in standardized units before undoing the scaling
generated_features = np.clip(generated_features, -3, 3)

# ===== Step 5: Assign a label to each generated sample =====
# The labels are drawn uniformly at random from disease_classes; they are not
# predicted from the generated features.
# NOTE(review): neither this draw nor the GAN noise is seeded, so every run
# produces a different dataset.
generated_labels = np.random.choice(disease_classes, size=num_missing)

# scikit-learn rejects an input with zero rows, so skip the back-transformation
# when there is nothing to replace.
if num_missing:
    to_replace[behavior_features] = scaler.inverse_transform(generated_features)
to_replace['physio_label'] = generated_labels

# Keep the indicator columns in step with physio_label, as the K-NN cell does
# in its Step 7; otherwise a row labelled e.g. 'mastitis' would still have
# mastitis == 0.
for cond in disease_classes:
    to_replace.loc[to_replace['physio_label'] == cond, cond] = 1

# ===== Step 6: Analyse & merge =====
reassignment_summary = pd.Series(generated_labels).value_counts().reset_index()
reassignment_summary.columns = ['Physiological_Class', 'Number_of_Reassigned_Samples']
print(reassignment_summary)

# Keep every row that already has a physiological label, including rows where
# a removed class was active as well: those rows are in neither clean_physio
# nor to_replace, so concatenating only these two would silently drop them.
already_labelled = df[df['physio_label'].notnull()]
df_final = pd.concat([already_labelled, to_replace], axis=0)
df_final = df_final.drop(columns=removed_classes + ['needs_replacement'])
df_final = df_final.sort_values(by=['cow', 'date', 'hour']).reset_index(drop=True)

df_final.to_csv(r'C:\Users\lamia\Downloads\GAN_cleaned_and_reassigned_dataset1.csv', index=False)

print(df_final[['cow', 'date', 'hour', 'physio_label']].sample(10))

Epoch 0, D Loss: 1.3360, G Loss: 0.7021
Epoch 200, D Loss: 1.0824, G Loss: 0.9755
Epoch 400, D Loss: 1.3870, G Loss: 0.7967
Epoch 600, D Loss: 1.2062, G Loss: 0.9506
Epoch 800, D Loss: 1.3789, G Loss: 0.6979
  Physiological_Class  Number_of_Reassigned_Samples
0            lameness                          1504
1             calving                          1475
2            mastitis                          1474
3             oestrus                          1461
4       other_disease                          1447
        cow        date  hour   physio_label
61512  6675  2019-03-16     9             ok
69344  6686  2019-04-07     5             ok
30151  6634  2019-02-19    21             ok
67532  6686  2019-01-15    12             ok
95595  6721  2018-12-11    14  other_disease
58703  6675  2018-11-14     5             ok
47706  6656  2018-11-29     8             ok
72278  6689  2019-02-20    14             ok
25076  6633  2019-01-02     2             ok
3800   6601  2019-04-08    13 