In [5]:
import pandas as pd
import re
import csv
from tqdm import tqdm
import torch
from transformers import pipeline, BitsAndBytesConfig
import os
import json

# --- Configuration ---
# Keep the Hugging Face cache on the larger scratch storage.
# HF_HOME is the *root* Hugging Face directory (tokens, datasets, ...), and the
# model/hub cache lives in its 'hub' subdirectory. So HF_HOME points at the
# parent of `cache_dir` and HF_HUB_CACHE at `cache_dir` itself. Setting HF_HOME
# to '.../hub' directly would nest downloads under '.../hub/hub'.
# NOTE: huggingface_hub reads these variables when it is first imported.
# transformers is imported above this point, so within this process they only
# take effect if also set before that import (e.g. in the job script or an
# earlier cell). Pass `cache_dir` explicitly when loading models to be safe.
cache_dir = '/scratch/project_2004147/.cache/huggingface/hub'
os.environ['HF_HOME'] = os.path.dirname(cache_dir)
os.environ['HF_HUB_CACHE'] = cache_dir
os.makedirs(cache_dir, exist_ok=True)

INPUT_FILENAME = 'bertopic_with_zeroshot_chatgpt/df_with_final_predictions.csv'
OUTPUT_FILENAME = 'bertopic_with_zeroshot_chatgpt/df_with_geo_predictions_mistral_v2.csv'
BACKUP_FILENAME = 'bertopic_with_zeroshot_chatgpt/df_with_geo_predictions_mistral_backup_v2.csv'
BACKUP_INTERVAL = 1000  # rows between periodic checkpoint writes to BACKUP_FILENAME

def _first_json_object(text):
    """
    Return the first JSON object embedded in *text*.

    Decoding is attempted at each '{' in turn with json.JSONDecoder.raw_decode,
    so any text after the object (including further braces) is ignored.

    Returns None when *text* contains no '{' at all. Raises
    json.JSONDecodeError (the error from the last attempt) when a '{' is
    present but no position decodes to a JSON object.
    """
    decoder = json.JSONDecoder()
    start = text.find('{')
    if start == -1:
        return None
    while True:
        try:
            obj, _end = decoder.raw_decode(text[start:])
            return obj
        except json.JSONDecodeError:
            start = text.find('{', start + 1)
            if start == -1:
                raise


def _clean_field(value):
    """Normalize one JSON field to a non-empty string; null/blank becomes 'N/A'."""
    if value is None:
        return 'N/A'
    cleaned = str(value).strip()
    return cleaned or 'N/A'


def extract_country_info_mistral(text, pipe, max_new_tokens=100):
    """
    Ask a text-generation pipeline which country a tweet refers to.

    The model is prompted for a JSON object with "country", "confidence" and
    "evidence" keys. The first JSON object in its reply is parsed.

    Args:
        text: The tweet text. Non-strings and blank strings are not sent to
            the model.
        pipe: A transformers text-generation pipeline (anything callable the
            same way and exposing ``pipe.tokenizer.eos_token_id`` works).
        max_new_tokens: Generation budget for the reply. The default leaves
            room for a complete three-key JSON object even when the evidence
            phrase is long; a truncated reply cannot be parsed.

    Returns:
        A ``(country, confidence, evidence)`` tuple of strings:
        - ``('N/A', 'N/A', 'N/A')`` for empty input or when the model reports
          no country (missing, null, blank or "N/A" country);
        - ``('Parse Error', 'Low', message)`` when the reply holds no JSON
          object or the JSON is invalid;
        - ``('Error', 'Error', message)`` when the pipeline call itself raises;
        - otherwise the model's country, confidence and evidence, with
          null/blank confidence or evidence replaced by 'N/A'.
    """
    if not isinstance(text, str) or not text.strip():
        return 'N/A', 'N/A', 'N/A'

    # A more sophisticated prompt asking for a structured JSON response.
    prompt = (
        "Analyze the following tweet and extract the country. "
        "Provide your answer as a JSON object with three keys: "
        "\"country\" (the country name), "
        "\"confidence\" (one of: \"High\", \"Moderate\", \"Low\", or \"N/A\"), and "
        "\"evidence\" (the specific word or phrase from the tweet that indicates the country). "
        "If no country is mentioned, the value for all keys should be \"N/A\".\n\n"
        f"Tweet: \"{text}\"\n\n"
        "JSON Answer:"
    )

    try:
        # return_full_text=False makes the pipeline return only the newly
        # generated text. Slicing off len(prompt) characters is fragile
        # because the decoded output need not reproduce the prompt exactly.
        outputs = pipe(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            return_full_text=False,
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
        json_part = outputs[0]['generated_text'].strip()
    except Exception as e:
        # Inference failures (e.g. out of GPU memory) are recorded per row
        # instead of aborting the whole batch.
        print(f"An error occurred during model inference: {e}")
        return 'Error', 'Error', str(e)

    try:
        data = _first_json_object(json_part)
    except json.JSONDecodeError:
        return 'Parse Error', 'Low', f"Invalid JSON: {json_part}"
    if data is None:
        return 'Parse Error', 'Low', 'No JSON object found'

    # Handle cases where the model returns empty strings, null or "N/A".
    country = _clean_field(data.get('country'))
    if country.lower() == 'n/a':
        return 'N/A', 'N/A', 'N/A'

    return country, _clean_field(data.get('confidence')), _clean_field(data.get('evidence'))


def _save_csv(df, path):
    """
    Write *df* to *path* with this script's CSV conventions, atomically.

    The frame is first written to '<path>.tmp' and then moved over *path*
    with os.replace. An interruption mid-write therefore cannot leave a
    truncated file for a later resume to load.
    """
    tmp_path = f"{path}.tmp"
    df.to_csv(tmp_path, index=False, quoting=csv.QUOTE_ALL, na_rep='N/A')
    os.replace(tmp_path, path)


def _load_dataframe():
    """
    Load the table to work on and the position of the first row to process.

    If BACKUP_FILENAME exists, it is loaded. Processing resumes at the first
    row in which any of Country/Confidence/Evidence is still 'UNPROCESSED'.
    Otherwise INPUT_FILENAME is loaded and those three columns are set to
    'UNPROCESSED'.

    Returns ``(df, start_index)``, or None if the data could not be loaded
    (the reason is printed).
    """
    status_columns = ['Country', 'Confidence', 'Evidence']
    read_options = dict(engine='python', on_bad_lines='warn', quoting=csv.QUOTE_MINIMAL, keep_default_na=False)

    if os.path.exists(BACKUP_FILENAME):
        print(f"Backup file found. Loading from '{BACKUP_FILENAME}' to resume.")
        try:
            df = pd.read_csv(BACKUP_FILENAME, **read_options)
            # Check all three columns: an interrupt can land between the
            # per-column writes of a row, leaving only some of them filled.
            pending = (df[status_columns] == 'UNPROCESSED').any(axis=1).to_numpy()
        except Exception as e:
            print(f"An error occurred while reading the backup CSV: {e}")
            return None
        # Rows are processed in order, so the first pending row is where to
        # resume. If nothing is pending, all work is done.
        start_index = int(pending.argmax()) if pending.any() else len(df)
        print(f"Resuming processing from row {start_index + 1}...")
        return df, start_index

    print(f"No backup file found. Starting from scratch with '{INPUT_FILENAME}'.")
    try:
        df = pd.read_csv(INPUT_FILENAME, **read_options)
    except FileNotFoundError:
        print(f"Error: Input file '{INPUT_FILENAME}' not found.")
        return None
    except Exception as e:
        print(f"An error occurred while reading the CSV: {e}")
        return None
    # A unique status that no model answer produces, marking rows still to do.
    for column in status_columns:
        df[column] = 'UNPROCESSED'
    return df, 0


def main():
    """
    Load the model and the tweets, tag every tweet with a country, and save.

    Each row's 'Country', 'Confidence' and 'Evidence' columns are filled from
    extract_country_info_mistral(row['text'], pipe). Progress is checkpointed
    to BACKUP_FILENAME every BACKUP_INTERVAL rows and when the run is
    interrupted with Ctrl-C, so a later run resumes where this one stopped.
    The table (complete or partial) is always written to OUTPUT_FILENAME,
    followed by a per-country count of the rows where a country was found.
    """
    print(f"Hugging Face cache directory set to: {cache_dir}")
    print("Loading Mistral location extractor model (boods/mistral-location-extractor-4bit)...")

    try:
        # The model is already 4-bit, so no quantization_config is needed.
        # cache_dir is passed explicitly. huggingface_hub reads HF_HOME when it
        # is first imported, so setting that variable after importing
        # transformers does not redirect downloads in this process.
        pipe = pipeline(
            "text-generation",
            model="boods/mistral-location-extractor-4bit",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            model_kwargs={"cache_dir": cache_dir},
        )
    except Exception as e:
        print(f"\nError loading the Hugging Face model: {e}")
        print("Please ensure you have run '!pip install transformers torch bitsandbytes accelerate' and have a stable internet connection.")
        return

    loaded = _load_dataframe()
    if loaded is None:
        return
    df, start_index = loaded

    # Fail early instead of with a KeyError on the first row of the loop.
    if 'text' not in df.columns:
        print("Error: the data has no 'text' column to analyze.")
        return

    print("Processing tweets with Mistral model...")

    interrupted = False
    if start_index >= len(df):
        print("All rows have already been processed. Nothing to do.")
    else:
        try:
            rows = df.iloc[start_index:].iterrows()
            progress = tqdm(rows, total=len(df), initial=start_index, desc="Analyzing tweets")
            for position, (index, row) in enumerate(progress, start=start_index):
                country, confidence, evidence = extract_country_info_mistral(row['text'], pipe)
                df.at[index, 'Country'] = country
                df.at[index, 'Confidence'] = confidence
                df.at[index, 'Evidence'] = evidence

                # Periodic checkpoint, counted by row position so that a
                # resumed run keeps the same schedule.
                if (position + 1) % BACKUP_INTERVAL == 0:
                    print(f"\nProcessed {position + 1} rows. Saving backup...")
                    _save_csv(df, BACKUP_FILENAME)

        except KeyboardInterrupt:
            interrupted = True
            print("\nProcess interrupted. Saving partial results...")
            # Also checkpoint here. Otherwise the next run would resume from
            # the last periodic backup (or from scratch) and redo this work.
            try:
                _save_csv(df, BACKUP_FILENAME)
            except OSError as e:
                print(f"Could not write the backup file: {e}")

    # Final save and summary
    print(f"\nSaving final data to '{OUTPUT_FILENAME}'...")
    try:
        _save_csv(df, OUTPUT_FILENAME)
    except Exception as e:
        print(f"An error occurred while saving the final CSV: {e}")
        return

    if interrupted:
        print("Partial results saved. Re-run to resume from the backup file.")
    else:
        print("Processing complete!")

    # Count found countries by excluding initial and error states.
    found_df = df[~df['Country'].isin(['N/A', 'UNPROCESSED', 'Error', 'Parse Error'])]
    print(f"\nFound countries in {len(found_df)} rows.")
    if not found_df.empty:
        print("Country distribution:\n", found_df['Country'].value_counts())


if __name__ == '__main__':
    # Entry point: runs when executed as a script, and also when this cell
    # runs in a notebook, where __name__ is '__main__'.
    main()


Hugging Face cache directory set to: /scratch/project_2004147/.cache/huggingface/hub
Loading Mistral location extractor model (boods/mistral-location-extractor-4bit)...


config.json: 0.00B [00:00, ?B/s]



model.safetensors:   0%|          | 0.00/4.45G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]



tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

Device set to use cuda:0


No backup file found. Starting from scratch with '/projappl/project_2004147/visions/bertopic_with_zeroshot_chatgpt/df_with_final_predictions.csv'.
Processing tweets with Mistral model...


Analyzing tweets:   0%|          | 0/432 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Analyzing tweets:   0%|          | 1/432 [00:05<37:23,  5.21s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Analyzing tweets:   0%|          | 2/432 [00:07<23:29,  3.28s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Analyzing tweets:   1%|          | 3/432 [00:09<18:49,  2.63s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Analyzing tweets:   1%|          | 4/432 [00:10<16:24,  2.30s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Analyzing twe


Saving final data to '/projappl/project_2004147/visions/bertopic_with_zeroshot_chatgpt/df_with_geo_predictions_mistral_v2.csv'...
Processing complete!

Found countries in 351 rows.
Country distribution:
 Country
Norway                    23
Finland                   21
Sweden                    19
Spain                     18
Italy                     17
Ireland                   17
Lithuania                 16
Latvia                    15
Kenya                     14
Brazil                    14
Poland                    14
Germany                   13
India                     13
Canada                    12
Netherlands               12
Estonia                   11
Ukraine                   11
Portugal                  11
Australia                 10
UK                        10
Wales                      9
France                     7
Iceland                    6
United Kingdom             6
USA                        6
Belgium                    4
Denmark                    4
Aust