In [97]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import nltk
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.decomposition import NMF

In [3]:
data = pd.read_csv('un-general-debates.csv')

In [8]:
data.head()

Unnamed: 0,session,year,country,text
0,44,1989,MDV,﻿It is indeed a pleasure for me and the member...
1,44,1989,FIN,"﻿\nMay I begin by congratulating you. Sir, on ..."
2,44,1989,NER,"﻿\nMr. President, it is a particular pleasure ..."
3,44,1989,URY,﻿\nDuring the debate at the fortieth session o...
4,44,1989,ZWE,﻿I should like at the outset to express my del...


In [37]:
stops = set(nltk.corpus.stopwords.words('english'))

In [7]:
bow = CountVectorizer(stop_words=stops)
bag = bow.fit_transform(data.text)

In [9]:
X_train, X_test, y_train, y_test = train_test_split(bag, data.country, random_state=3)

In [14]:
X = X_train.toarray()

In [19]:
X.shape

(5630, 54754)

In [17]:
nmf = NMF(n_components=4)

In [20]:
W = nmf.fit_transform(X)

In [21]:
H = nmf.components_

In [22]:
words = bow.get_feature_names()
words = np.array(words)

In [28]:
words[np.argsort(H, axis=1)][:,-20:]

array([['also', 'general', 'african', 'problems', 'assembly',
        'community', 'government', 'us', 'new', 'country', 'people',
        'south', 'must', 'development', 'developing', 'africa',
        'international', 'economic', 'world', 'countries'],
       ['assembly', 'nuclear', 'council', 'one', 'organization', 'also',
        'rights', 'us', 'security', 'global', 'peace', 'general', 'new',
        'development', 'human', 'states', 'must', 'world', 'united',
        'nations'],
       ['policy', 'independence', 'would', 'arab', 'security',
        'relations', 'soviet', 'military', 'war', 'international',
        'nuclear', 'republic', 'countries', 'nations', 'peoples',
        'world', 'peace', 'united', 'people', 'states'],
       ['stability', 'assembly', 'support', 'cooperation', 'human',
        'process', 'region', 'rights', 'political', 'country',
        'economic', 'also', 'council', 'general', 'efforts',
        'development', 'community', 'peace', 'security', 'interna

In [25]:
countries = np.array(data.country)

In [27]:
countries[np.argsort(W, axis=0)[-30:,:]].T

array([['ZMB', 'PRY', 'FIN', 'CAF', 'CPV', 'EGY', 'COM', 'COG', 'CAF',
        'ARE', 'LBN', 'NAM', 'GNQ', 'PAK', 'PAK', 'DZA', 'FRA', 'UKR',
        'ZMB', 'AUT', 'AUS', 'KIR', 'CHE', 'DZA', 'CHE', 'BHS', 'NIC',
        'IND', 'LKA', 'MUS'],
       ['GBR', 'AUS', 'STP', 'SYR', 'TJK', 'GRD', 'DZA', 'GRD', 'STP',
        'MNE', 'UKR', 'BGR', 'SOM', 'NOR', 'TGO', 'FIN', 'CAN', 'MRT',
        'GTM', 'LBR', 'MUS', 'ZMB', 'ETH', 'CUB', 'ZMB', 'USA', 'COD',
        'STP', 'PRT', 'RUS'],
       ['IND', 'ZWE', 'SYR', 'HND', 'POL', 'PRY', 'BDI', 'ECU', 'PRK',
        'ARG', 'POL', 'MEX', 'UKR', 'IND', 'AND', 'DEU', 'GNB', 'DMA',
        'GNQ', 'TUN', 'GIN', 'NAM', 'NIC', 'THA', 'GUY', 'LBY', 'RUS',
        'HND', 'BGR', 'ARG'],
       ['SLV', 'IDN', 'SWZ', 'AGO', 'MNE', 'MNG', 'USA', 'ARG', 'ISR',
        'NPL', 'LUX', 'LCA', 'BTN', 'GAB', 'HTI', 'MYS', 'IDN', 'LBY',
        'SUR', 'RWA', 'VCT', 'LIE', 'EGY', 'GRD', 'GNQ', 'GNQ', 'MWI',
        'MDG', 'CHE', 'LCA']], dtype=object)

In [46]:
stops2 = stops.union(set(['general', 'assembly',
        'government', 'country', 'people',
        'international', 'world', 'countries', 'also',
        'united', 'must',
        'nations',
        'policy', 'relations', 'nations', 'peoples',
        'human', 'region', 'rights', 'political', 'country',
        'council']))

In [48]:
bow = CountVectorizer(stop_words=stops2)
bag = bow.fit_transform(data.text)

In [49]:
X_train, X_test, y_train, y_test = train_test_split(bag, data.country, random_state=3)
X = X_train.toarray()

In [50]:
nmf = NMF(n_components=4)
W = nmf.fit_transform(X)
H = nmf.components_

In [51]:
words = bow.get_feature_names()
words = np.array(words)
words[np.argsort(H, axis=1)][:,-20:]

array([['conference', 'economic', 'operation', 'europe', 'one', 'arms',
        'co', 'disarmament', 'new', 'union', 'military', 'war', 'would',
        'security', 'peace', 'republic', 'weapons', 'soviet', 'nuclear',
        'states'],
       ['one', 'work', 'developing', 'year', 'need', 'process', 'social',
        'states', 'support', 'organization', 'cooperation', 'efforts',
        'us', 'community', 'global', 'economic', 'peace', 'new',
        'security', 'development'],
       ['one', 'order', 'independence', 'delegation', 'operation', 'us',
        'peace', 'new', 'situation', 'problems', 'community', 'co',
        'development', 'session', 'developing', 'organization',
        'african', 'south', 'economic', 'africa'],
       ['right', 'palestine', 'iraq', 'organization', 'community',
        'would', 'lebanon', 'war', 'israeli', 'middle', 'resolutions',
        'states', 'state', 'east', 'efforts', 'palestinian', 'arab',
        'israel', 'security', 'peace']], dtype='<U23')

In [52]:
countries[np.argsort(W, axis=0)[-30:,:]].T

array([['HND', 'DEU', 'SUR', 'AUS', 'HRV', 'MEX', 'NPL', 'KEN', 'IND',
        'GHA', 'NAM', 'AFG', 'ARG', 'DEU', 'ARG', 'GNB', 'BDI', 'PRK',
        'GNQ', 'IND', 'BGR', 'ARG', 'THA', 'TUN', 'GIN', 'LBY', 'NIC',
        'GUY', 'RUS', 'HND'],
       ['BGR', 'ZMB', 'LKA', 'GUY', 'PRY', 'NOR', 'GAB', 'SVK', 'YDYE',
        'SOM', 'IRL', 'NGA', 'STP', 'GAB', 'YEM', 'RUS', 'DOM', 'YUG',
        'VAT', 'BRA', 'IRN', 'IRQ', 'POL', 'GBR', 'KWT', 'MKD', 'NER',
        'DMA', 'KWT', 'GRD'],
       ['ROU', 'BFA', 'LCA', 'BTN', 'UGA', 'PER', 'MNG', 'VNM', 'DEU',
        'CHL', 'VEN', 'QAT', 'NAM', 'LAO', 'YUG', 'NGA', 'LKA', 'IND',
        'HND', 'TUR', 'ARE', 'NIC', 'CHE', 'KIR', 'VNM', 'DEU', 'ETH',
        'AUS', 'MUS', 'MUS'],
       ['BRN', 'SLE', 'RUS', 'SVK', 'SLV', 'FJI', 'MOZ', 'CIV', 'PHL',
        'URY', 'ITA', 'IDN', 'TJK', 'DOM', 'DMA', 'FIN', 'QAT', 'GNQ',
        'MEX', 'YEM', 'LBY', 'ZWE', 'COM', 'MDV', 'PRY', 'BLZ', 'ZMB',
        'TUR', 'EGY', 'MEX']], dtype=object)

In [54]:
np.argsort(W, axis=0)

array([[2814, 2814, 2581, 2361],
       [ 754, 2357, 4250, 2186],
       [1981, 2338, 1665, 2176],
       ...,
       [4675, 5054, 1006, 3942],
       [5066, 1611, 4533, 1925],
       [3337, 1500, 2021, 2120]])

In [56]:
np.sort(W, axis=0)

array([[0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        ],
       ...,
       [3.87125617, 1.78559766, 2.85631664, 3.4395753 ],
       [4.06599555, 2.12263286, 2.90593081, 3.50794053],
       [4.31557433, 2.21749968, 3.13738628, 3.89855217]])

In [57]:
W

array([[0.36590932, 0.2622573 , 0.76940277, 0.23348519],
       [0.14519882, 0.85221998, 0.00616481, 0.07914887],
       [0.10192813, 0.75230255, 0.03070059, 0.31388216],
       ...,
       [0.36230902, 0.22981603, 1.69103623, 0.85745633],
       [0.24095471, 0.48875515, 0.        , 0.        ],
       [0.20259112, 0.        , 1.04166087, 0.59106358]])

In [61]:
len(W)/4

1407.5

In [63]:
top_countries = pd.DataFrame(countries[np.argsort(W, axis=0)[-1407:,:]])

In [83]:
topic0 = top_countries[0].groupby(by=top_countries[0]).count()

In [84]:
topic1 = top_countries[1].groupby(by=top_countries[1]).count()

In [85]:
topic2 = top_countries[2].groupby(by=top_countries[2]).count()

In [86]:
topic3 = top_countries[3].groupby(by=top_countries[3]).count()

In [91]:
topic0[topic0 > 12]

0
BEL    13
DEU    13
GTM    15
LKA    14
MEX    15
MOZ    13
PHL    13
RUS    13
STP    15
TUN    13
TUR    13
URY    13
Name: 0, dtype: int64

In [92]:
topic1[topic1 > 12]

1
BEL    13
BEN    13
CMR    13
COD    13
CRI    13
IRN    18
LBY    15
MUS    13
NGA    13
PRY    13
RUS    14
Name: 1, dtype: int64

In [94]:
topic2[topic2 > 12]

2
EGY    16
GBR    13
LKA    17
MDV    13
MNG    14
TCD    13
USA    14
Name: 2, dtype: int64

In [95]:
topic3[topic3 > 12]

3
BFA    13
IRQ    17
ISL    14
LAO    13
RUS    13
RWA    13
Name: 3, dtype: int64