# In [None]:
import psycopg2
import csv
from dotenv import load_dotenv
import os
import pandas as pd

# Load environment variables with python-dotenv at import time, so the DB_*
# settings looked up through os.getenv() in connect_to_db() are available
# before any connection is attempted.
load_dotenv()

def connect_to_db() -> psycopg2.extensions.connection:
    """Open a PostgreSQL connection configured from environment variables.

    The connection parameters are looked up with ``os.getenv`` from
    ``DB_HOST``, ``DB_PORT``, ``DB_NAME``, ``DB_USER`` and ``DB_PASSWORD``.
    A variable that is not set is forwarded as ``None``.

    Returns:
        The connection object returned by ``psycopg2.connect``.
    """
    # psycopg2.connect keyword -> name of the environment variable feeding it.
    env_names = {
        "host": "DB_HOST",
        "port": "DB_PORT",
        "database": "DB_NAME",
        "user": "DB_USER",
        "password": "DB_PASSWORD",
    }
    settings = {keyword: os.getenv(env_name) for keyword, env_name in env_names.items()}
    return psycopg2.connect(**settings)

def read_gnd_mappings(file_path: str) -> tuple:
    """Read the GND mapping CSV and return its de-duplicated contents.

    The CSV must provide the columns ``doc_idn``, ``gnd_idn`` and
    ``gnd_label``; any further columns are ignored.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A 3-tuple ``(doc_idns, gnd_entities, mappings)``:

        * ``doc_idns`` -- list of the distinct ``doc_idn`` values.
        * ``gnd_entities`` -- list of distinct ``[gnd_idn, gnd_label]`` pairs.
        * ``mappings`` -- list of distinct ``[doc_idn, gnd_idn]`` pairs.

        Every list keeps the order of first appearance in the file.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    print("Reading CSV file...")
    # Identifiers are read as text: with dtype inference an all-numeric column
    # becomes int64, which strips leading zeros and no longer matches the
    # VARCHAR identifier columns these values are compared against later.
    text_columns = {"doc_idn": str, "gnd_idn": str, "gnd_label": str}
    df = pd.read_csv(file_path, dtype=text_columns).drop_duplicates()

    # Get unique doc_idns
    doc_idns = df['doc_idn'].unique().tolist()

    # Get unique GND entities
    gnd_entities = df[['gnd_idn', 'gnd_label']].drop_duplicates().values.tolist()

    # De-duplicate the projected pairs explicitly: dropping duplicate *rows*
    # above is not enough, because two rows that differ only in the label or
    # in an unrelated column still yield the same (doc_idn, gnd_idn) pair.
    mappings = df[['doc_idn', 'gnd_idn']].drop_duplicates().values.tolist()

    print(f"Found {len(doc_idns)} unique documents")
    print(f"Found {len(gnd_entities)} unique GND entities")
    print(f"Found {len(mappings)} total mappings (after removing duplicates)")

    return doc_idns, gnd_entities, mappings

def create_gnd_tables(cur: psycopg2.extensions.cursor) -> None:
    """Drop and recreate the three GND subset tables.

    Tables handled, all through ``cur``:

    * ``dnb_records_gnd`` -- empty copy of the ``dnb_records`` column layout,
      with a primary key on ``id`` and an index on ``idn``.
    * ``dnb_gnd_entities`` -- one row per GND entity, ``gnd_idn`` unique.
    * ``dnb_gnd_mappings`` -- link table referencing the two tables above.

    Nothing is committed here; that is left to the caller.
    """
    # Dependent tables are dropped before the tables they reference.
    drop_sql = """
        DROP TABLE IF EXISTS dnb_gnd_mappings;
        DROP TABLE IF EXISTS dnb_gnd_entities;
        DROP TABLE IF EXISTS dnb_records_gnd;
    """

    # "WHERE FALSE" copies the column definitions without copying any row.
    records_sql = """
        CREATE TABLE dnb_records_gnd AS 
        SELECT * FROM dnb_records WHERE FALSE;
        ALTER TABLE dnb_records_gnd ADD PRIMARY KEY (id);
        CREATE INDEX idx_records_gnd_idn ON dnb_records_gnd(idn);
    """

    entities_sql = """
        CREATE TABLE dnb_gnd_entities (
            id SERIAL PRIMARY KEY,
            gnd_idn VARCHAR NOT NULL UNIQUE,
            gnd_label TEXT NOT NULL
        );
        CREATE INDEX idx_gnd_entities_idn ON dnb_gnd_entities(gnd_idn);
    """

    mappings_sql = """
        CREATE TABLE dnb_gnd_mappings (
            record_id INTEGER REFERENCES dnb_records_gnd(id),
            gnd_id INTEGER REFERENCES dnb_gnd_entities(id),
            PRIMARY KEY (record_id, gnd_id)
        );
    """

    print("Dropping existing tables...")
    cur.execute(drop_sql)

    print("Creating new tables...")
    # Referenced tables first, the link table last.
    for create_sql in (records_sql, entities_sql, mappings_sql):
        cur.execute(create_sql)

def populate_tables(cur: psycopg2.extensions.cursor, doc_idns: list, gnd_entities: list, mappings: list) -> None:
    """Populate the GND subset tables from the parsed CSV data.

    Steps, in order:

    1. Insert ``gnd_entities`` into ``dnb_gnd_entities``. Skipped when the
       table's row count already equals ``len(gnd_entities)``.
    2. Copy every row of ``dnb_records`` whose ``idn`` is in ``doc_idns``
       into ``dnb_records_gnd``.
    3. Stage ``mappings`` in a temporary table, resolve both identifiers to
       their row ids and insert the pairs into ``dnb_gnd_mappings``.

    Bulk inserts use ``psycopg2.extras.execute_values``, which sends
    multi-row ``VALUES`` lists; ``cursor.executemany`` issues one statement
    per row and is no faster than calling ``execute`` in a loop.

    Args:
        cur: Open cursor on the target database.
        doc_idns: Record identifiers to copy from ``dnb_records``.
        gnd_entities: ``[gnd_idn, gnd_label]`` pairs.
        mappings: ``[doc_idn, gnd_idn]`` pairs.

    Side effects:
        Commits on ``cur.connection`` after every entity chunk. Work done
        earlier in the same transaction is committed with it and can no
        longer be rolled back by the caller.
    """
    # Imported here so this function does not rely on a module-level
    # "import psycopg2.extras" being present.
    from psycopg2.extras import execute_values

    chunk_size = 1000

    # Check if we need to populate GND entities
    cur.execute("SELECT COUNT(*) FROM dnb_gnd_entities")
    current_count = cur.fetchone()[0]

    if current_count != len(gnd_entities):
        print("Inserting GND entities...")
        total = len(gnd_entities)
        for start in range(0, total, chunk_size):
            chunk = gnd_entities[start:start + chunk_size]
            # One multi-row INSERT per chunk. ON CONFLICT DO NOTHING also
            # tolerates a gnd_idn that occurs twice within the same chunk.
            execute_values(
                cur,
                """
                INSERT INTO dnb_gnd_entities (gnd_idn, gnd_label)
                VALUES %s
                ON CONFLICT (gnd_idn) DO NOTHING
                """,
                chunk,
                page_size=chunk_size,
            )
            print(f"Processed {min(start + chunk_size, total)}/{total} entities")
            # Commit after each chunk to save progress
            cur.connection.commit()
    else:
        print("GND entities already populated, skipping...")

    print("Copying records...")
    cur.execute("""
        INSERT INTO dnb_records_gnd 
        SELECT * FROM dnb_records 
        WHERE idn = ANY(%s)
        ON CONFLICT DO NOTHING
    """, (doc_idns,))

    print("Creating mappings...")
    # Stage the raw identifier pairs so they can be resolved to row ids with
    # a single set-based join instead of one lookup per pair.
    cur.execute("""
        CREATE TEMP TABLE temp_mappings (
            doc_idn VARCHAR,
            gnd_idn VARCHAR
        )
    """)

    execute_values(
        cur,
        "INSERT INTO temp_mappings (doc_idn, gnd_idn) VALUES %s",
        mappings,
        page_size=chunk_size,
    )

    # Resolve identifiers to ids; pairs whose record or entity is absent are
    # dropped by the inner joins, duplicates by DISTINCT / ON CONFLICT.
    cur.execute("""
        INSERT INTO dnb_gnd_mappings (record_id, gnd_id)
        SELECT DISTINCT r.id, e.id
        FROM temp_mappings m
        JOIN dnb_records_gnd r ON r.idn = m.doc_idn::varchar
        JOIN dnb_gnd_entities e ON e.gnd_idn = m.gnd_idn::varchar
        ON CONFLICT DO NOTHING
    """)

    # Clean up
    cur.execute("DROP TABLE temp_mappings")

def main(csv_path: str = "../data/ger_open_access_with_gnd.csv") -> None:
    """Rebuild the GND subset tables from a CSV file and print row counts.

    Reads the CSV, connects to the database, recreates and populates the
    three GND tables, prints the resulting row counts and commits.

    Args:
        csv_path: Path of the GND mapping CSV. The default is the path the
            script has always used; it is resolved against the current
            working directory.

    Raises:
        Exception: Any error raised while building the tables is re-raised
            after the open transaction has been rolled back.
    """
    print("Starting process...")
    # Read GND mappings
    doc_idns, gnd_entities, mappings = read_gnd_mappings(csv_path)

    # Connect to database
    print("Connecting to database...")
    conn = connect_to_db()

    try:
        # The cursor is acquired inside the try block so the connection is
        # still closed if cursor creation itself fails.
        with conn.cursor() as cur:
            create_gnd_tables(cur)
            populate_tables(cur, doc_idns, gnd_entities, mappings)

            # Print statistics
            print("\nFinal Statistics:")
            cur.execute("SELECT COUNT(*) FROM dnb_records_gnd")
            print(f"Records in subset: {cur.fetchone()[0]}")

            cur.execute("SELECT COUNT(*) FROM dnb_gnd_entities")
            print(f"GND entities: {cur.fetchone()[0]}")

            cur.execute("SELECT COUNT(*) FROM dnb_gnd_mappings")
            print(f"Total mappings: {cur.fetchone()[0]}")

            conn.commit()
            print("\nDone!")
    except Exception:
        conn.rollback()
        # Bare raise re-raises the active exception with its traceback intact.
        raise
    finally:
        conn.close()

# Run the import only when this file is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    main()