In [1]:
import pandas as pd
import numpy as np
from typing import Tuple, List
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

In [2]:
def _min_max_scale(values: np.ndarray) -> np.ndarray:
    """Rescale a 1-D series to the [0, 1] range; a constant series maps to all zeros."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        raise ValueError("cannot normalize an empty series")
    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        # max == min would give 0/0 = NaN for every point and poison every
        # downstream cost, so a flat series is represented as all zeros.
        return np.zeros_like(values)
    return (values - lowest) / span


def compute_directional_dtw(google_trends: np.ndarray, disease: np.ndarray, radius: int) -> Tuple[float, List]:
    """
    Compute a banded DTW-style alignment cost from the Google Trends perspective.

    Both series are min-max normalized to [0, 1]. Every Google Trends point i is
    matched to exactly one disease point j with |i - j| <= radius, and the cost of
    a match is the absolute difference of the normalized values. The accumulated
    cost of cell (i, j) is its own cost plus the cheapest accumulated cost found in
    row i-1 among columns j-radius .. j+radius. Because that window is symmetric
    around j, the disease index of consecutive matches may move backwards as well
    as forwards.

    Args:
        google_trends: Google Trends series (one value per time step).
        disease: Disease series to align against.
        radius: Maximum index distance allowed between matched points, and
            between the disease indices of consecutive matches.

    Returns:
        (best_cost, best_path): the lowest accumulated cost found in the last
        Google Trends row, and the matches that produced it as a list of (i, j)
        index tuples, one per Google Trends point. If no cell of the last row is
        reachable within the band, the cost is inf and the path is empty.

    Raises:
        ValueError: if either series is empty.
    """
    gt_norm = _min_max_scale(google_trends)
    disease_norm = _min_max_scale(disease)

    n = len(gt_norm)  # Length of Google Trends series
    m = len(disease_norm)  # Length of disease series

    # D holds accumulated costs. Instead of storing a full path copy per cell,
    # keep one back-pointer per cell (column chosen in the previous row) and
    # rebuild the single winning path at the end.
    D = np.full((n, m), np.inf)
    back = np.full((n, m), -1, dtype=int)
    reached = np.zeros((n, m), dtype=bool)

    for i in range(n):
        # Disease indices within `radius` of the current Google Trends index
        start_j = max(0, i - radius)
        end_j = min(m, i + radius + 1)

        for j in range(start_j, end_j):
            cost = abs(gt_norm[i] - disease_norm[j])
            if i == 0:
                # First row: every in-band cell starts a path of its own.
                D[i, j] = cost
                reached[i, j] = True
                continue

            window_start = max(0, j - radius)
            prev_costs = D[i - 1, window_start:min(m, j + radius + 1)]
            if len(prev_costs) == 0:
                continue
            offset = np.argmin(prev_costs)
            min_prev_cost = prev_costs[offset]
            if min_prev_cost != np.inf:
                D[i, j] = cost + min_prev_cost
                back[i, j] = window_start + offset
                reached[i, j] = True

    # Best end point is the cheapest cell of the last Google Trends row
    final_row = D[n - 1, :]
    best_end = int(np.argmin(final_row))
    best_cost = final_row[best_end]

    best_path = []
    if reached[n - 1, best_end]:
        # Walk the back-pointers from the last row up to row 0, then reverse.
        j = best_end
        for i in range(n - 1, -1, -1):
            best_path.append((i, j))
            if i > 0:
                j = int(back[i, j])
        best_path.reverse()

    return best_cost, best_path

In [3]:
def plot_directional_dtw(google_trends: np.ndarray, disease: np.ndarray, dates: np.ndarray,
                        tag: str, metric_name: str, source: str, radius: int) -> None:
    """
    Plot both normalized series with their DTW matches and save the figure as a PNG.

    The disease series is drawn shifted up by 1.5 so the two curves do not overlap,
    and dashed gray lines connect matched points. The figure is written to
    ``dtw_directional_<source>_<tag>_<metric>_r<radius>.png`` in the current working
    directory and then closed; nothing is displayed inline.

    Args:
        google_trends: Google Trends time series (red line)
        disease: Disease cases time series (blue line, already filtered to match dates)
        dates: Array of dates matching the Google Trends timeline
        tag: Name of the Google Trends tag
        metric_name: Name of disease metric (Confirmed/Active Cases)
        source: Data source (MSV/RSV)
        radius: Temporal radius for matching
    """
    def _unit_scale(values):
        # Min-max scale to [0, 1]; a flat series becomes all zeros instead of 0/0 = NaN.
        values = np.asarray(values, dtype=float)
        lowest = np.min(values)
        span = np.max(values) - lowest
        if span == 0:
            return np.zeros_like(values)
        return (values - lowest) / span

    # Normalize series
    gt_norm = _unit_scale(google_trends)
    disease_norm = _unit_scale(disease)

    # Only the alignment path is drawn; the DTW cost itself is not shown.
    _, path = compute_directional_dtw(google_trends, disease, radius)

    # Create figure
    fig, ax = plt.subplots(figsize=(20, 10))
    try:
        # Plot time series
        ax.plot(dates, disease_norm + 1.5, label=f'{metric_name}', color='blue', linewidth=2)
        ax.plot(dates, gt_norm, label=f'{tag} ({source})', color='red', linewidth=2)

        # Draw a connector for every match whose Google Trends index is even,
        # i.e. every second point, to keep the figure readable.
        for i, j in path:
            if i % 2 == 0:
                ax.plot([dates[i], dates[j]], [gt_norm[i], disease_norm[j] + 1.5],
                        'gray', alpha=0.9, linestyle='--')

        ax.set_title(f'Directional DTW Alignment (±{radius} days): {tag} ({source}) vs {metric_name}',
                     fontsize=25, pad=20)
        ax.legend(fontsize=13, loc='upper right')

        # Remove y-axis ticks and labels (values are normalized and offset)
        ax.set_yticks([])
        ax.set_ylabel('')

        # Format x-axis: one tick per month, labelled like "March 2020"
        formatter = mdates.DateFormatter('%B %Y')
        ax.xaxis.set_major_formatter(formatter)
        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
        ax.tick_params(axis='x', labelsize=15)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        # Pad the x-axis by 5% of the covered date span on each side
        date_range = dates[-1] - dates[0]
        padding = date_range * 0.05
        ax.set_xlim(dates[0] - padding, dates[-1] + padding)

        fig.tight_layout()

        filename = f"dtw_directional_{source.lower()}_{tag}_{metric_name.replace(' ', '_')}_r{radius}.png"
        fig.savefig(filename, bbox_inches='tight', dpi=300)
    finally:
        # Close this specific figure even when plotting or saving fails, so a
        # caller that catches the error and keeps looping does not pile up
        # open figures.
        plt.close(fig)

In [4]:
def analyze_tags_directional(data_df: pd.DataFrame, disease_df: pd.DataFrame, 
                           tags: List[str], source: str, metric_name: str, 
                           radii: List[int]) -> pd.DataFrame:
    """
    Score each tag against the disease series with directional DTW for every radius.

    The disease frame is cut down to the date span of ``data_df`` and inner-joined
    to it on 'date', so both series cover exactly the same days. For every
    (tag, radius) pair a DTW score is computed and a plot is generated. A failure
    for one pair is printed with its traceback and does not stop the remaining
    pairs.

    Returns:
        DataFrame with one row per successfully scored (tag, radius) pair and the
        columns source, tag, comparison, radius, dtw_score.
    """
    # Date span covered by the Google Trends data
    first_day = data_df['date'].min()
    last_day = data_df['date'].max()

    # Keep only the disease rows that fall inside that span
    inside_span = (disease_df['date'] >= first_day) & (disease_df['date'] <= last_day)
    disease_in_span = disease_df[inside_span].copy()

    # Inner join keeps only the dates present in both frames
    aligned = pd.merge(data_df, disease_in_span, on='date', how='inner')

    rows = []
    for tag in tags:
        print(f"\nProcessing {source} tag: {tag}")

        tag_values = aligned[tag].values
        # The disease values are read from the LAST column of the merged frame.
        # NOTE(review): assumes disease_df carries a single value column besides
        # 'date' -- confirm against the input files.
        case_values = aligned.iloc[:, -1].values
        day_values = aligned['date'].values

        for radius in radii:
            try:
                score, _ = compute_directional_dtw(tag_values, case_values, radius)
                rows.append(dict(
                    source=source,
                    tag=tag,
                    comparison=metric_name,
                    radius=radius,
                    dtw_score=score,
                ))

                # The score row is recorded before plotting, so a plotting
                # failure does not drop the result.
                plot_directional_dtw(tag_values, case_values, day_values,
                                     tag, metric_name, source, radius)
            except Exception as e:
                print(f"Error processing {tag} with radius {radius}: {str(e)}")
                import traceback
                print(traceback.format_exc())

    return pd.DataFrame(rows)

In [5]:
# Load the two Google Trends exports (MSV and RSV) and the two disease
# case-count tables (confirmed daily cases, active cases).
# Paths are relative to the notebook's working directory.
msv_df = pd.read_csv("../gt_preprocessed_data/gt_msv_stitched/3_gt_msv_stitched_compute.csv")
rsv_df = pd.read_csv("../gt_preprocessed_data/gt_rsv_stitched/3_gt_rescaled_rsv.csv")
confirmed_df = pd.read_csv("../gt_stat_analysis/disease_confirmed_daily_cases.csv")
active_df = pd.read_csv("../gt_stat_analysis/disease_active_cases.csv")

In [6]:
# Parse the 'date' column of every frame to datetime so the frames can be
# compared and merged on it. This overwrites the column in place on each frame.
for df in [msv_df, rsv_df, confirmed_df, active_df]:
    df['date'] = pd.to_datetime(df['date'])

# Every non-date column of a Google Trends frame is treated as a search tag.
msv_tags = [col for col in msv_df.columns if col != 'date']
rsv_tags = [col for col in rsv_df.columns if col != 'date']

In [7]:
# Matching radii evaluated for every tag; the plot titles label them as "± days".
radii = [7, 15, 20, 30, 50]

In [8]:
# Score every MSV tag against the confirmed-cases table for each radius.
# Side effect: analyze_tags_directional also saves one PNG per (tag, radius).
print("Processing MSV vs Confirmed Cases")
results_msv_confirmed = analyze_tags_directional(msv_df, confirmed_df, msv_tags, 
                                               'MSV', 'Confirmed Cases', radii)

Processing MSV vs Confirmed Cases

Processing MSV tag: flu

Processing MSV tag: cough

Processing MSV tag: fever

Processing MSV tag: headache

Processing MSV tag: lagnat

Processing MSV tag: rashes

Processing MSV tag: sipon

Processing MSV tag: ubo

Processing MSV tag: ecq

Processing MSV tag: face shield

Processing MSV tag: Frontliners

Processing MSV tag: masks

Processing MSV tag: Quarantine

Processing MSV tag: social distancing

Processing MSV tag: work from home


In [9]:
# Score every RSV tag against the confirmed-cases table for each radius.
# Side effect: analyze_tags_directional also saves one PNG per (tag, radius).
print("\nProcessing RSV vs Confirmed Cases")
results_rsv_confirmed = analyze_tags_directional(rsv_df, confirmed_df, rsv_tags, 
                                               'RSV', 'Confirmed Cases', radii)


Processing RSV vs Confirmed Cases

Processing RSV tag: flu

Processing RSV tag: cough

Processing RSV tag: fever

Processing RSV tag: headache

Processing RSV tag: lagnat

Processing RSV tag: rashes

Processing RSV tag: sipon

Processing RSV tag: ubo

Processing RSV tag: ecq

Processing RSV tag: face shield

Processing RSV tag: Frontliners

Processing RSV tag: masks

Processing RSV tag: Quarantine

Processing RSV tag: social distancing

Processing RSV tag: work from home


In [10]:
# Score every MSV tag against the active-cases table for each radius.
# Side effect: analyze_tags_directional also saves one PNG per (tag, radius).
print("\nProcessing MSV vs Active Cases")
results_msv_active = analyze_tags_directional(msv_df, active_df, msv_tags, 
                                            'MSV', 'Active Cases', radii)


Processing MSV vs Active Cases

Processing MSV tag: flu

Processing MSV tag: cough

Processing MSV tag: fever

Processing MSV tag: headache

Processing MSV tag: lagnat

Processing MSV tag: rashes

Processing MSV tag: sipon

Processing MSV tag: ubo

Processing MSV tag: ecq

Processing MSV tag: face shield

Processing MSV tag: Frontliners

Processing MSV tag: masks

Processing MSV tag: Quarantine

Processing MSV tag: social distancing

Processing MSV tag: work from home


In [14]:
# NOTE(review): this cell repeats the earlier "RSV vs Confirmed Cases" run with
# identical arguments. It recomputes results_rsv_confirmed and rewrites the same
# PNG files; the execution counts also jump here, which suggests an out-of-order
# re-run. Consider removing this cell.
print("\nProcessing RSV vs Confirmed Cases")
results_rsv_confirmed = analyze_tags_directional(rsv_df, confirmed_df, rsv_tags, 
                                               'RSV', 'Confirmed Cases', radii)


Processing RSV vs Confirmed Cases

Processing RSV tag: flu

Processing RSV tag: cough

Processing RSV tag: fever

Processing RSV tag: headache

Processing RSV tag: lagnat

Processing RSV tag: rashes

Processing RSV tag: sipon

Processing RSV tag: ubo

Processing RSV tag: ecq

Processing RSV tag: face shield

Processing RSV tag: Frontliners

Processing RSV tag: masks

Processing RSV tag: Quarantine

Processing RSV tag: social distancing

Processing RSV tag: work from home


In [15]:
# Score every RSV tag against the active-cases table for each radius.
# Side effect: analyze_tags_directional also saves one PNG per (tag, radius).
print("\nProcessing RSV vs Active Cases")
results_rsv_active = analyze_tags_directional(rsv_df, active_df, rsv_tags, 
                                            'RSV', 'Active Cases', radii)


Processing RSV vs Active Cases

Processing RSV tag: flu

Processing RSV tag: cough

Processing RSV tag: fever

Processing RSV tag: headache

Processing RSV tag: lagnat

Processing RSV tag: rashes

Processing RSV tag: sipon

Processing RSV tag: ubo

Processing RSV tag: ecq

Processing RSV tag: face shield

Processing RSV tag: Frontliners

Processing RSV tag: masks

Processing RSV tag: Quarantine

Processing RSV tag: social distancing

Processing RSV tag: work from home


In [16]:
# Stack the four result tables into one frame.
# Each table keeps its own row labels, so index labels repeat across tables;
# the 'source' and 'comparison' columns identify where a row came from.
all_results = pd.concat([
    results_msv_confirmed,
    results_msv_active,
    results_rsv_confirmed,
    results_rsv_active
])

In [17]:
# Move dtw_score to the last column; the other columns keep their order
columns = [col for col in all_results.columns if col != 'dtw_score'] + ['dtw_score']
all_results = all_results[columns]

# Order rows from lowest to highest DTW score and write them to CSV
# (row labels are not written)
results_sorted = all_results.sort_values('dtw_score', ascending=True)
results_sorted.to_csv('tags_directional_dtw_results.csv', index=False)

In [19]:
# Show the ten rows with the lowest DTW scores across all sources and comparisons.
# The left-most printed number is the row label inherited from the per-run tables.
print("\nTop 10 matches overall (lowest DTW scores):")
print(results_sorted.head(10)[['source', 'tag', 'comparison', 'radius', 'dtw_score']])


Top 10 matches overall (lowest DTW scores):
   source          tag       comparison  radius  dtw_score
49    RSV  face shield  Confirmed Cases      50   2.027824
48    RSV  face shield  Confirmed Cases      30   3.796949
47    RSV  face shield  Confirmed Cases      20   5.581157
4     RSV          flu  Confirmed Cases      50   6.583811
46    RSV  face shield  Confirmed Cases      15   7.054873
3     RSV          flu  Confirmed Cases      30   9.062418
2     RSV          flu  Confirmed Cases      20  11.425761
45    RSV  face shield  Confirmed Cases       7  12.255652
44    RSV          ecq  Confirmed Cases      50  12.523408
64    RSV   Quarantine  Confirmed Cases      50  13.463491
