<a href="https://colab.research.google.com/github/fjadidi2001/AD_Prediction/blob/main/ADReSSo21.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import tarfile
import pandas as pd
import numpy as np
from pathlib import Path
import shutil

class ADReSSo21TranscriptExtractor:
    """Unpack the ADReSSo21 archives and gather their per-recording CSV files.

    The archives named in ``datasets`` are looked up directly under
    ``base_path`` and unpacked into ``<base_path>/extracted``. The combined
    tables are written back to ``base_path`` as
    ``progression_transcripts.csv`` and ``diagnosis_transcripts.csv``.

    NOTE(review): every CSV is read from a ``.../segmentation`` folder; going
    by that name these files may hold speaker-segment timings rather than
    transcript text — confirm against the dataset before treating them as
    transcripts.
    """

    # Encodings tried, in order, when reading a CSV. 'cp1252' raises on a few
    # undefined bytes (e.g. 0x81), while 'latin-1' accepts every byte, so
    # 'latin-1' must stay last; anything listed after it would be unreachable.
    _CSV_ENCODINGS = ('utf-8', 'cp1252', 'latin-1')

    def __init__(self, base_path="/drive/MyDrive/Voice/"):
        """Remember where the archives live and where to unpack them.

        Args:
            base_path: Directory that contains the ``.tgz`` archives; outputs
                are written here as well.
        """
        # NOTE(review): Colab normally mounts Google Drive at /content/drive;
        # this default assumes Drive was mounted at /drive — confirm.
        self.base_path = base_path
        # Logical dataset name -> archive file name inside base_path.
        self.datasets = {
            'progression_train': 'ADReSSo21-progression-train.tgz',
            'progression_test': 'ADReSSo21-progression-test.tgz',
            'diagnosis_train': 'ADReSSo21-diagnosis-train.tgz'
        }
        self.extracted_path = os.path.join(base_path, "extracted")

    def extract_datasets(self):
        """Unpack every archive in ``self.datasets`` into ``self.extracted_path``.

        Missing archives are reported with a warning and skipped.
        """
        print("Extracting datasets...")

        # Create extraction directory
        os.makedirs(self.extracted_path, exist_ok=True)

        # The 'data' extraction filter refuses members or links that would
        # resolve outside the destination and special files such as devices.
        # It exists on Python 3.12+ and on security backports of older
        # versions; fall back to the legacy behaviour where it is missing.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

        for filename in self.datasets.values():
            file_path = os.path.join(self.base_path, filename)

            if not os.path.exists(file_path):
                print(f"⚠ Warning: {filename} not found at {file_path}")
                continue

            print(f"Extracting {filename}...")
            with tarfile.open(file_path, 'r:gz') as tar:
                tar.extractall(path=self.extracted_path, **extract_kwargs)
            print(f"✓ {filename} extracted successfully")

    def find_csv_files(self, directory):
        """Return every ``.csv`` file below ``directory``, searched recursively.

        The extension check is case-insensitive, and the list is sorted so the
        row order of the combined output does not depend on the filesystem's
        directory-traversal order.
        """
        csv_files = []
        for root, _dirs, files in os.walk(directory):
            csv_files.extend(
                os.path.join(root, name)
                for name in files
                if name.lower().endswith('.csv')
            )
        return sorted(csv_files)

    def read_transcript_csv(self, csv_path):
        """Read one CSV file and tag each row with the file it came from.

        Each encoding in ``_CSV_ENCODINGS`` is tried in turn until one decodes
        the file.

        Args:
            csv_path: Path of the CSV file.

        Returns:
            The DataFrame with two added columns, ``file_id`` (file name
            without extension) and ``file_path``, or ``None`` if the file could
            not be decoded or parsed (a message is printed in that case).
        """
        try:
            for encoding in self._CSV_ENCODINGS:
                try:
                    df = pd.read_csv(csv_path, encoding=encoding)
                    break
                except UnicodeDecodeError:
                    continue
            else:
                print(f"⚠ Could not read {csv_path} with any encoding")
                return None

            # Get filename without extension for ID
            file_id = os.path.splitext(os.path.basename(csv_path))[0]

            # Add file info
            df['file_id'] = file_id
            df['file_path'] = csv_path

            return df

        except Exception as e:
            # Best-effort per file: report and let the caller skip it.
            print(f"Error reading {csv_path}: {str(e)}")
            return None

    def _load_labeled_csvs(self, directory, label, description):
        """Read every CSV below ``directory`` and set its ``label`` column.

        Args:
            directory: Folder to search; a missing folder yields ``[]``.
            label: Value written into each DataFrame's ``label`` column.
            description: Folder name shown in the progress message only.

        Returns:
            A list of DataFrames, one per readable CSV file; unreadable files
            are skipped (``read_transcript_csv`` reports them).
        """
        if not os.path.exists(directory):
            return []

        csv_files = self.find_csv_files(directory)
        print(f"Found {len(csv_files)} CSV files in {description} directory")

        frames = []
        for csv_file in csv_files:
            df = self.read_transcript_csv(csv_file)
            if df is not None:
                df['label'] = label
                frames.append(df)
        return frames

    def extract_progression_transcripts(self):
        """Load the progression train (decline / no_decline) and test CSVs.

        Returns:
            ``{'train': {'decline': [...], 'no_decline': [...]}, 'test': [...]}``
            where each list holds per-file DataFrames; test rows are labelled
            ``'test'``.
        """
        train_path = os.path.join(self.extracted_path, "ADReSSo21/progression/train/segmentation")
        test_path = os.path.join(self.extracted_path, "ADReSSo21/progression/test-dist/segmentation")

        # Dict literals evaluate in order, so progress messages print as
        # decline, no_decline, test.
        return {
            'train': {
                'decline': self._load_labeled_csvs(
                    os.path.join(train_path, "decline"), 'decline', 'decline'),
                'no_decline': self._load_labeled_csvs(
                    os.path.join(train_path, "no_decline"), 'no_decline', 'no_decline'),
            },
            'test': self._load_labeled_csvs(test_path, 'test', 'test'),
        }

    def extract_diagnosis_transcripts(self):
        """Load the diagnosis training CSVs for the ``ad`` and ``cn`` folders.

        Returns:
            ``{'ad': [...], 'cn': [...]}`` with one DataFrame per readable file,
            labelled ``'ad'`` or ``'cn'`` respectively.
        """
        base_path = os.path.join(self.extracted_path, "ADReSSo21/diagnosis/train/segmentation")

        return {
            # AD (Alzheimer's Disease) cases
            'ad': self._load_labeled_csvs(os.path.join(base_path, "ad"), 'ad', 'AD'),
            # CN (Cognitive Normal) cases
            'cn': self._load_labeled_csvs(os.path.join(base_path, "cn"), 'cn', 'CN'),
        }

    def combine_and_save_transcripts(self, transcripts, dataset_name):
        """Concatenate per-file DataFrames and save them as one CSV.

        Args:
            transcripts: The dict returned by ``extract_progression_transcripts``
                or ``extract_diagnosis_transcripts``.
            dataset_name: ``'progression'`` or ``'diagnosis'``; selects the dict
                layout and the output name ``<dataset_name>_transcripts.csv``.

        Returns:
            The combined DataFrame, or ``None`` when there is nothing to combine
            or ``dataset_name`` is not recognised.
        """
        if dataset_name == 'progression':
            groups = [
                transcripts['train']['decline'],
                transcripts['train']['no_decline'],
                transcripts['test'],
            ]
        elif dataset_name == 'diagnosis':
            groups = [transcripts['ad'], transcripts['cn']]
        else:
            print(f"⚠ Unknown dataset name: {dataset_name!r}")
            return None

        frames = [df for group in groups for df in group]
        if not frames:
            return None

        final_df = pd.concat(frames, ignore_index=True)

        # Save to CSV
        output_path = os.path.join(self.base_path, f"{dataset_name}_transcripts.csv")
        final_df.to_csv(output_path, index=False)
        print(f"✓ Saved {len(final_df)} transcript records to {output_path}")

        return final_df

    def display_sample_data(self, df, dataset_name):
        """Print a summary of ``df``: row count, label counts, columns, head.

        If any column name contains 'text', 'transcript' or 'word'
        (case-insensitive), up to three non-null values of the first such
        column are also printed, each cut to 200 characters.
        """
        print(f"\n=== {dataset_name.upper()} DATASET SUMMARY ===")
        print(f"Total records: {len(df)}")

        if 'label' in df.columns:
            print("\nLabel distribution:")
            print(df['label'].value_counts())

        print(f"\nColumns: {list(df.columns)}")

        print("\nSample data:")
        print(df.head())

        # Show some transcript samples if available
        text_columns = [
            col for col in df.columns
            if any(key in str(col).lower() for key in ('text', 'transcript', 'word'))
        ]
        if text_columns:
            print(f"\nSample transcript content from column '{text_columns[0]}':")
            for i, text in enumerate(df[text_columns[0]].dropna().head(3)):
                print(f"Sample {i+1}: {str(text)[:200]}...")

    def run_extraction(self):
        """Run the whole pipeline: unpack, load, combine, save and summarise.

        Returns:
            ``(progression_df, diagnosis_df)``; either may be ``None`` when no
            CSV files were found for that dataset.
        """
        print("Starting ADReSSo21 transcript extraction...")

        # Extract datasets
        self.extract_datasets()

        # Extract progression transcripts
        print("\n" + "="*50)
        print("EXTRACTING PROGRESSION TRANSCRIPTS")
        print("="*50)
        progression_transcripts = self.extract_progression_transcripts()
        progression_df = self.combine_and_save_transcripts(progression_transcripts, 'progression')

        if progression_df is not None:
            self.display_sample_data(progression_df, 'progression')

        # Extract diagnosis transcripts
        print("\n" + "="*50)
        print("EXTRACTING DIAGNOSIS TRANSCRIPTS")
        print("="*50)
        diagnosis_transcripts = self.extract_diagnosis_transcripts()
        diagnosis_df = self.combine_and_save_transcripts(diagnosis_transcripts, 'diagnosis')

        if diagnosis_df is not None:
            self.display_sample_data(diagnosis_df, 'diagnosis')

        print("\n" + "="*50)
        print("EXTRACTION COMPLETE!")
        print("="*50)

        return progression_df, diagnosis_df

# Usage: when this file/cell is executed directly, run the full pipeline and
# bind the extractor and both combined tables as module-level names.
if __name__ == "__main__":
    extractor = ADReSSo21TranscriptExtractor()  # default base_path
    progression_df, diagnosis_df = extractor.run_extraction()

    # The stages can also be invoked one at a time, e.g.:
    #     extractor.extract_datasets()
    #     progression_transcripts = extractor.extract_progression_transcripts()
    #     diagnosis_transcripts = extractor.extract_diagnosis_transcripts()

Starting ADReSSo21 transcript extraction...
Extracting datasets...

EXTRACTING PROGRESSION TRANSCRIPTS

EXTRACTING DIAGNOSIS TRANSCRIPTS

EXTRACTION COMPLETE!
