In [1]:
import traceback
from typing import Tuple, List

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
import pandas as pd
from scipy.spatial.distance import euclidean

In [2]:
def create_temporal_constraint_mask(length: int, radius: int) -> np.ndarray:
    """
    Build a square boolean mask for the temporal window.

    Entry ``[i, j]`` is True when ``|i - j| > radius`` (the pair lies outside
    the allowed band around the diagonal) and False otherwise.
    """
    # Allocate first so an invalid ``length`` fails exactly where np.zeros does.
    mask = np.zeros((length, length), dtype=bool)

    # Pairwise index distance via broadcasting: column vector minus row vector.
    positions = np.arange(length)
    offsets = np.abs(positions[:, np.newaxis] - positions[np.newaxis, :])
    mask[offsets > radius] = True

    return mask

In [3]:
def _normalize_unit_interval(series: np.ndarray) -> np.ndarray:
    """Min-max scale ``series`` to [0, 1]; a constant series maps to all zeros."""
    values = np.asarray(series)
    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        # A flat series has no range to scale by; dividing would produce NaN.
        return np.zeros_like(values, dtype=float)
    return (values - lowest) / span


def compute_dtw_with_temporal_constraint(series1: np.ndarray, series2: np.ndarray, radius: int) -> Tuple[float, List]:
    """
    Compute DTW with strict temporal constraints.

    Both series are min-max normalized independently, the point-wise cost is
    the absolute difference of the normalized values, and only cells with
    ``|i - j| <= radius`` may be part of the warping path.

    Args:
        series1: First 1-D series (indexed by ``i``).
        series2: Second 1-D series (indexed by ``j``).
        radius: Half-width of the allowed band around the diagonal.

    Returns:
        ``(distance, path)`` where ``distance`` is the accumulated cost at the
        last cell and ``path`` is a list of ``(i, j)`` tuples running from
        ``(0, 0)`` to ``(n-1, m-1)``. If the last cell lies outside the band
        (possible only when ``|n - m| > radius``) the distance is ``inf`` and
        the path is empty, because no valid alignment exists.

    Raises:
        ValueError: If either series is empty.
    """
    if np.asarray(series1).size == 0 or np.asarray(series2).size == 0:
        raise ValueError("series1 and series2 must each contain at least one value")

    # Normalize series to [0,1] range (constant series become all zeros)
    s1_norm = _normalize_unit_interval(series1)
    s2_norm = _normalize_unit_interval(series2)

    n, m = len(s1_norm), len(s2_norm)

    # Pairwise cost |s1[i] - s2[j]| for every (i, j) via broadcasting
    cost_matrix = np.abs(s1_norm[:, np.newaxis] - s2_norm[np.newaxis, :])

    # Accumulated cost; cells outside the band are never written and stay inf
    D = np.full((n, m), np.inf)
    D[0, 0] = cost_matrix[0, 0]

    # |i - j| > radius is equivalent to |i - j| > floor(radius) for integer
    # index differences, so only the columns inside the band are visited.
    window = int(np.floor(radius))

    for i in range(n):
        j_start = max(0, i - window)
        j_stop = min(m, i + window + 1)
        for j in range(j_start, j_stop):
            if i == 0 and j == 0:
                continue

            candidates = []
            if i > 0:
                candidates.append(D[i-1, j])
            if j > 0:
                candidates.append(D[i, j-1])
            if i > 0 and j > 0:
                candidates.append(D[i-1, j-1])

            D[i, j] = cost_matrix[i, j] + min(candidates)

    # No alignment satisfies the constraint: do not fabricate a path through
    # cells that were excluded by the band.
    if np.isinf(D[-1, -1]):
        return D[-1, -1], []

    # Backtrack from the last cell, always stepping to the cheapest predecessor
    # (ties resolved in the order: up, left, diagonal).
    path = []
    i, j = n-1, m-1
    path.append((i, j))

    while i > 0 or j > 0:
        candidates = []
        if i > 0:
            candidates.append((D[i-1, j], i-1, j))
        if j > 0:
            candidates.append((D[i, j-1], i, j-1))
        if i > 0 and j > 0:
            candidates.append((D[i-1, j-1], i-1, j-1))

        _, i, j = min(candidates, key=lambda x: x[0])
        path.append((i, j))

    path.reverse()

    return D[-1, -1], path

In [4]:
def plot_dtw_alignment(series1: np.ndarray, series2: np.ndarray, dates: np.ndarray,
                      tag: str, metric_name: str, source: str, radius: int) -> None:
    """
    Plot the DTW alignment between two time series and save it as a PNG.

    Both series are min-max scaled for display. ``series1`` is drawn shifted up
    by 1.5 and ``series2`` below it, both against ``dates``. Every third step
    of the warping path is drawn as a dashed line between the matched points.

    Args:
        series1: Series plotted on top and labelled with ``metric_name``.
        series2: Series plotted at the bottom and labelled ``tag (source)``.
        dates: X-axis values, indexed by the same positions as both series.
        tag: Label of the bottom series; also inserted verbatim in the filename.
        metric_name: Label of the top series; spaces become underscores in the
            filename.
        source: Source label; lower-cased in the filename.
        radius: Temporal window passed to the DTW computation.

    Returns:
        None. The figure is written to
        ``dtw_alignment_<source>_<tag>_<metric>_r<radius>.png`` relative to the
        current working directory and then closed.
    """
    def _unit_scale(values):
        # Min-max scale to [0, 1]; a flat series is drawn as a line at 0
        # instead of dividing by zero and producing NaN.
        values = np.asarray(values)
        lowest = np.min(values)
        span = np.max(values) - lowest
        if span == 0:
            return np.zeros_like(values, dtype=float)
        return (values - lowest) / span

    # Normalize series for visualization
    s1_norm = _unit_scale(series1)
    s2_norm = _unit_scale(series2)

    # Only the warping path is needed here; the distance is discarded
    _, path = compute_dtw_with_temporal_constraint(series1, series2, radius)

    # Create figure
    fig, ax = plt.subplots(figsize=(20, 10))

    try:
        # Plot the first time series at the top
        ax.plot(dates, s1_norm + 1.5, label=f'{metric_name}', color='blue', linewidth=2)

        # Plot the second time series at the bottom
        ax.plot(dates, s2_norm, label=f'{tag} ({source})', color='red', linewidth=2)

        # Draw matching lines between points
        for i, j in path[::3]:  # Plot every 3rd line to reduce visual clutter
            ax.plot([dates[i], dates[j]], [s1_norm[i] + 1.5, s2_norm[j]],
                    'gray', alpha=0.9, linestyle='--')

        # Customize the plot
        ax.set_title(f'DTW Alignment (±{radius} days): {tag} ({source}) vs {metric_name}',
                     fontsize=25, pad=20)
        ax.legend(fontsize=13, loc='upper right')

        # Remove y-axis ticks and labels (the vertical offset is arbitrary)
        ax.set_yticks([])
        ax.set_ylabel('')

        # Format x-axis dates
        formatter = mdates.DateFormatter('%B %Y')
        ax.xaxis.set_major_formatter(formatter)
        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))

        # Rotate and align the tick labels
        ax.tick_params(axis='x', labelsize=15)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        # Adjust layout
        fig.tight_layout()

        # Save the plot
        filename = f"dtw_alignment_{source.lower()}_{tag}_{metric_name.replace(' ', '_')}_r{radius}.png"
        fig.savefig(filename, bbox_inches='tight', dpi=300)
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # repeated calls in a loop do not accumulate open figures.
        plt.close(fig)

In [5]:
def analyze_temporal_dtw(data_df: pd.DataFrame, disease_df: pd.DataFrame, 
                        tags: List[str], source: str, metric_name: str, 
                        radii: List[int]) -> pd.DataFrame:
    """
    Analyze DTW distances with strict temporal constraints for multiple tags.

    The two frames are inner-joined on ``'date'``; for every tag and every
    radius the DTW score between the disease series and the tag series is
    computed and an alignment plot is generated.

    Args:
        data_df: Frame with a ``'date'`` column and one column per tag.
        disease_df: Frame with a ``'date'`` column and the disease series.
        tags: Column names of ``data_df`` to compare against the disease series.
        source: Label stored in the ``source`` column and used in plots.
        metric_name: Label stored in the ``comparison`` column and used in plots.
        radii: Temporal window sizes to evaluate.

    Returns:
        DataFrame with columns ``source``, ``tag``, ``comparison``, ``radius``
        and ``dtw_score``: one row per (tag, radius) whose score was computed.
        The columns are present even when no row was produced.
    """
    result_columns = ['source', 'tag', 'comparison', 'radius', 'dtw_score']
    results = []

    # Ensure date alignment
    merged_df = pd.merge(data_df, disease_df, on='date', how='inner')

    if merged_df.empty:
        # Nothing to align: report once instead of failing for every tag/radius.
        print(f"No overlapping dates for {source} vs {metric_name}; skipping analysis")
        return pd.DataFrame(results, columns=result_columns)

    # These do not depend on the tag, so extract them once.
    # NOTE(review): the disease series is taken as the LAST column of the
    # merged frame - this assumes disease_df contributes the final column.
    disease_series = merged_df.iloc[:, -1].values
    dates = merged_df['date'].values

    for tag in tags:
        print(f"\nProcessing {source} tag: {tag}")

        # Get aligned series
        signal_series = merged_df[tag].values

        # Compute DTW for each radius
        for radius in radii:
            try:
                dtw_score, path = compute_dtw_with_temporal_constraint(
                    disease_series, signal_series, radius
                )

                results.append({
                    'source': source,
                    'tag': tag,
                    'comparison': metric_name,
                    'radius': radius,
                    'dtw_score': dtw_score
                })

                # Generate visualization
                plot_dtw_alignment(disease_series, signal_series, dates,
                                 tag, metric_name, source, radius)

            except Exception as e:
                # Best effort: report the failure and continue with the next
                # radius. A score appended before a plotting failure is kept.
                print(f"Error processing {tag} with radius {radius}: {str(e)}")
                print(traceback.format_exc())

    # Explicit columns keep the schema stable when ``results`` is empty, so
    # downstream concat/sort_values on these columns still works.
    return pd.DataFrame(results, columns=result_columns)

In [6]:
# Load datasets: two search-signal frames (MSV, RSV) and two disease frames
# (confirmed daily cases, active cases), read in that order.
msv_df, rsv_df, confirmed_df, active_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../gt_preprocessed_data/gt_msv_stitched/3_gt_msv_stitched_compute.csv",
        "../gt_preprocessed_data/gt_rsv_stitched/3_gt_rescaled_rsv.csv",
        "../gt_stat_analysis/disease_confirmed_daily_cases.csv",
        "../gt_stat_analysis/disease_active_cases.csv",
    )
)

In [7]:
# Parse each frame's 'date' column into pandas datetimes (assigned back in place)
for df in (msv_df, rsv_df, confirmed_df, active_df):
    df['date'] = df['date'].pipe(pd.to_datetime)

In [8]:
# Every column other than 'date' is treated as a tag
msv_tags = msv_df.columns[msv_df.columns != 'date'].tolist()
rsv_tags = rsv_df.columns[rsv_df.columns != 'date'].tolist()

In [9]:
# Define radii for analysis: each value is the half-width of the DTW temporal
# window (cells with |i - j| > radius are excluded), and every tag is
# evaluated once per radius.
radii = [7, 15, 20, 30, 50]

In [10]:
# Process MSV data: run the DTW analysis of every MSV tag against each
# disease series in turn.
print("Processing MSV vs Confirmed Cases")
results_msv_confirmed = analyze_temporal_dtw(
    data_df=msv_df,
    disease_df=confirmed_df,
    tags=msv_tags,
    source='MSV',
    metric_name='Confirmed Cases',
    radii=radii,
)

print("\nProcessing MSV vs Active Cases")
results_msv_active = analyze_temporal_dtw(
    data_df=msv_df,
    disease_df=active_df,
    tags=msv_tags,
    source='MSV',
    metric_name='Active Cases',
    radii=radii,
)

Processing MSV vs Confirmed Cases

Processing MSV tag: flu

Processing MSV tag: cough

Processing MSV tag: fever

Processing MSV tag: headache

Processing MSV tag: lagnat

Processing MSV tag: rashes

Processing MSV tag: sipon

Processing MSV tag: ubo

Processing MSV tag: ecq

Processing MSV tag: face shield

Processing MSV tag: Frontliners

Processing MSV tag: masks

Processing MSV tag: Quarantine

Processing MSV tag: social distancing

Processing MSV tag: work from home

Processing MSV vs Active Cases

Processing MSV tag: flu

Processing MSV tag: cough

Processing MSV tag: fever

Processing MSV tag: headache

Processing MSV tag: lagnat

Processing MSV tag: rashes

Processing MSV tag: sipon

Processing MSV tag: ubo

Processing MSV tag: ecq

Processing MSV tag: face shield

Processing MSV tag: Frontliners

Processing MSV tag: masks

Processing MSV tag: Quarantine

Processing MSV tag: social distancing

Processing MSV tag: work from home


In [11]:
# Process RSV data: run the DTW analysis of every RSV tag against each
# disease series in turn.
print("\nProcessing RSV vs Confirmed Cases")
results_rsv_confirmed = analyze_temporal_dtw(
    data_df=rsv_df,
    disease_df=confirmed_df,
    tags=rsv_tags,
    source='RSV',
    metric_name='Confirmed Cases',
    radii=radii,
)

print("\nProcessing RSV vs Active Cases")
results_rsv_active = analyze_temporal_dtw(
    data_df=rsv_df,
    disease_df=active_df,
    tags=rsv_tags,
    source='RSV',
    metric_name='Active Cases',
    radii=radii,
)


Processing RSV vs Confirmed Cases

Processing RSV tag: flu

Processing RSV tag: cough

Processing RSV tag: fever

Processing RSV tag: headache

Processing RSV tag: lagnat

Processing RSV tag: rashes

Processing RSV tag: sipon

Processing RSV tag: ubo

Processing RSV tag: ecq

Processing RSV tag: face shield

Processing RSV tag: Frontliners

Processing RSV tag: masks

Processing RSV tag: Quarantine

Processing RSV tag: social distancing

Processing RSV tag: work from home

Processing RSV vs Active Cases

Processing RSV tag: flu

Processing RSV tag: cough

Processing RSV tag: fever

Processing RSV tag: headache

Processing RSV tag: lagnat

Processing RSV tag: rashes

Processing RSV tag: sipon

Processing RSV tag: ubo

Processing RSV tag: ecq

Processing RSV tag: face shield

Processing RSV tag: Frontliners

Processing RSV tag: masks

Processing RSV tag: Quarantine

Processing RSV tag: social distancing

Processing RSV tag: work from home


In [12]:
# Combine all results: stack the four per-comparison frames row-wise
# (original row indices are kept, so index values repeat across frames).
all_results = pd.concat(
    objs=(
        results_msv_confirmed,
        results_msv_active,
        results_rsv_confirmed,
        results_rsv_active,
    )
)

In [13]:
# Save results sorted by radius and DTW score
results_sorted = all_results.sort_values(by=['radius', 'dtw_score'])
results_sorted.to_csv(path_or_buf='dtw_results_multi_radius.csv', index=False)

# Save results sorted only by DTW score (ascending is the default order)
results_sorted_by_score = all_results.sort_values(by='dtw_score')
results_sorted_by_score.to_csv(path_or_buf='dtw_results_multi_radius_sorted.csv', index=False)