# Sklearn-cotraining examples

This notebook contains a series of examples of semi-supervised co-training using the sklearn_cotraining module https://github.com/jjrob13/sklearn_cotraining. The implementation follows the ideas presented in "Combining labeled and unlabeled data with co-training" (Blum and Mitchell, 1998) https://www.cs.cmu.edu/~avrim/Papers/cotrain.pdf. The module already imports numpy as np, random and copy.
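The co-training loop from the Blum and Mitchell paper can be sketched in plain scikit-learn. This is a minimal illustration, not the sklearn_cotraining implementation: the helper names `co_train` and `predict_two_views` are hypothetical, both views use `LogisticRegression`, and labels follow the notebook's convention (`-1` for unlabeled rows, `{0, 1}` otherwise). The parameters follow the paper's notation: in each of `k` rounds, each view's classifier self-labels its `p` most confident positives and `n` most confident negatives from a pool U' of `u` unlabeled examples, and U' is then replenished with `2p + 2n` examples from U. The defaults mirror the settings reported in the paper (p=1, n=3, k=30, u=75). Combining the two views by multiplying their class probabilities is one common choice, assumed here.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


def co_train(X1, X2, y, p=1, n=3, k=30, u=75, seed=0):
    """Hypothetical co-training sketch; y uses -1 for unlabeled rows, {0, 1} otherwise."""
    rng = np.random.default_rng(seed)
    y = y.copy()
    unlabeled = np.flatnonzero(y == -1)
    rng.shuffle(unlabeled)
    pool, rest = list(unlabeled[:u]), list(unlabeled[u:])  # U' and the remainder of U
    h1 = LogisticRegression(max_iter=1000)  # classifier for view 1
    h2 = LogisticRegression(max_iter=1000)  # classifier for view 2

    def refit():
        labeled = y != -1
        h1.fit(X1[labeled], y[labeled])
        h2.fit(X2[labeled], y[labeled])

    for _ in range(k):
        if not pool:
            break
        refit()
        idx = np.array(pool)
        newly_labeled = set()
        for h, Xv in ((h1, X1), (h2, X2)):
            # Rank pool items by this view's probability of class 1 (ascending)
            order = np.argsort(h.predict_proba(Xv[idx])[:, 1])
            picks = [(i, 1) for i in order[::-1][:p]] + [(i, 0) for i in order[:n]]
            for i, label in picks:
                if idx[i] not in newly_labeled:  # first confident label wins
                    y[idx[i]] = label
                    newly_labeled.add(idx[i])
        # Drop the self-labeled items from U' and replenish it with 2p + 2n from U
        pool = [j for j in pool if j not in newly_labeled]
        pool += rest[: 2 * p + 2 * n]
        rest = rest[2 * p + 2 * n:]
    refit()  # final fit on the labeled plus self-labeled rows
    return h1, h2


def predict_two_views(h1, h2, X1, X2):
    """Combine the two views by multiplying their class probabilities."""
    proba = h1.predict_proba(X1) * h2.predict_proba(X2)
    return h1.classes_[np.argmax(proba, axis=1)]


# Usage on a small synthetic problem: 1500 unlabeled, 300 labeled and 200 test rows
X, y_true = make_classification(n_samples=2000, n_features=20, n_informative=10, random_state=0)
y = y_true.copy()
y[:1500] = -1
X1, X2 = X[:, :10], X[:, 10:]
h1, h2 = co_train(X1[:1800], X2[:1800], y[:1800])
y_pred = predict_two_views(h1, h2, X1[1800:], X2[1800:])
print("accuracy:", (y_pred == y_true[1800:]).mean())
```

Refitting once after the loop makes the examples self-labeled in the last round count toward the returned classifiers; the paper's loop simply stops after k rounds.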

The `__init__.py` files of the original module have been adapted so that it can be used from a non-local installation (the import of CoTrainingClassifier was changed), and the code has also been ported to Python 3. The SVM usage example has been corrected: it called base_lr.predict instead of base_svm.predict.

In [1]:
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.metrics import classification_report
from sklearn.datasets import make_classification
from sklearn_cotraining.classifiers import CoTrainingClassifier

In [2]:
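N_SAMPLES = 25000
N_FEATURES = 1000
X, y = make_classification(n_samples=N_SAMPLES, n_features=N_FEATURES)

# Hide the labels of the first half: -1 marks an unlabeled sample for CoTrainingClassifier
y[:N_SAMPLES//2] = -1

# The last quarter is held out as the test set
X_test = X[-N_SAMPLES//4:]
y_test = y[-N_SAMPLES//4:]

# The third quarter keeps its labels; the baselines are trained on this part only
X_labeled = X[N_SAMPLES//2:-N_SAMPLES//4]
y_labeled = y[N_SAMPLES//2:-N_SAMPLES//4]

# Co-training sees the first three quarters: the unlabeled half plus the labeled quarter
y = y[:-N_SAMPLES//4]
X = X[:-N_SAMPLES//4]

# Two views: the first half and the second half of the feature columns
X1 = X[:, :N_FEATURES // 2]
X2 = X[:, N_FEATURES // 2:]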
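The slices in the cell above partition the 25,000 rows into three disjoint blocks. Applying the same slices to an index array (a stand-in for the rows of `X`, introduced only for illustration) makes the block sizes explicit. Note that `-N_SAMPLES//4` parses as `(-N_SAMPLES)//4`, which equals `-(N_SAMPLES//4)` here because 25,000 is divisible by 4.

```python
import numpy as np

N_SAMPLES = 25000
rows = np.arange(N_SAMPLES)  # stand-in for the row positions of X

unlabeled = rows[:N_SAMPLES//2]              # labels overwritten with -1
labeled = rows[N_SAMPLES//2:-N_SAMPLES//4]   # X_labeled / y_labeled
test = rows[-N_SAMPLES//4:]                  # X_test / y_test
train = rows[:-N_SAMPLES//4]                 # X / y passed to CoTrainingClassifier.fit

print(len(unlabeled), len(labeled), len(test), len(train))  # 12500 6250 6250 18750
```

So the baselines train on 6,250 labeled rows, co-training sees those same 6,250 labeled rows plus 12,500 unlabeled ones, and both are scored on the same 6,250 test rows (the `support` total in the reports below).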

In [3]:
print('Logistic')
base_lr = LogisticRegression()
base_lr.fit(X_labeled, y_labeled)
y_pred = base_lr.predict(X_test)
print(classification_report(y_test, y_pred))

Logistic
             precision    recall  f1-score   support

          0       0.87      0.87      0.87      3088
          1       0.87      0.87      0.87      3162

avg / total       0.87      0.87      0.87      6250



In [4]:
print('Logistic CoTraining')
lg_co_clf = CoTrainingClassifier(LogisticRegression())
lg_co_clf.fit(X1, X2, y)
y_pred = lg_co_clf.predict(X_test[:, :N_FEATURES // 2], X_test[:, N_FEATURES // 2:])
print(classification_report(y_test, y_pred))

Logistic CoTraining
             precision    recall  f1-score   support

          0       0.92      0.93      0.92      3088
          1       0.93      0.92      0.93      3162

avg / total       0.93      0.93      0.93      6250
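Co-training assumes that each view is informative enough to learn the target on its own (one of the assumptions in Blum and Mitchell's analysis). With `make_classification`, the informative and redundant columns are mixed among the noise columns by default (`shuffle=True` shuffles features as well as samples), so nothing guarantees the two halves carry equal signal. A quick sketch for scoring each view in isolation, on a smaller synthetic problem whose sizes are assumptions chosen for speed:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

N_FEATURES = 40
# Default n_informative=2 / n_redundant=2, as in the notebook's make_classification call
X, y = make_classification(n_samples=3000, n_features=N_FEATURES, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)

half = N_FEATURES // 2
views = {"view 1": slice(0, half), "view 2": slice(half, N_FEATURES), "both": slice(0, N_FEATURES)}
scores = {}
for name, cols in views.items():
    clf = LogisticRegression(max_iter=1000).fit(X_tr[:, cols], y_tr)
    scores[name] = accuracy_score(y_te, clf.predict(X_te[:, cols]))
    print(f"{name}: {scores[name]:.3f}")
```

If one view scores close to chance, its classifier contributes little reliable self-labeling, which is worth keeping in mind when reading the co-training reports.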



In [5]:
print('SVM')
base_svm = LinearSVC()
base_svm.fit(X_labeled, y_labeled)
y_pred = base_svm.predict(X_test)
print(classification_report(y_test, y_pred))

SVM
             precision    recall  f1-score   support

          0       0.86      0.86      0.86      3088
          1       0.87      0.86      0.87      3162

avg / total       0.86      0.86      0.86      6250



In [6]:
print('SVM CoTraining')
svm_co_clf = CoTrainingClassifier(LinearSVC(), u=N_SAMPLES//10)
svm_co_clf.fit(X1, X2, y)
y_pred = svm_co_clf.predict(X_test[:, :N_FEATURES // 2], X_test[:, N_FEATURES // 2:])
print(classification_report(y_test, y_pred))

SVM CoTraining
             precision    recall  f1-score   support

          0       0.91      0.91      0.91      3088
          1       0.92      0.91      0.91      3162

avg / total       0.91      0.91      0.91      6250
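`LinearSVC` provides `decision_function` but no `predict_proba`. If class probabilities are needed from an SVM view, for example to rank its most confident unlabeled examples or to multiply probabilities across views, one option in scikit-learn is to wrap it in `CalibratedClassifierCV`. Whether sklearn_cotraining itself calls `predict_proba` on its classifiers is not verified here; this is a standalone sketch on assumed synthetic data.

```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC

X, y = make_classification(n_samples=500, n_features=20, random_state=0)

svm = LinearSVC().fit(X, y)
print(hasattr(svm, "predict_proba"))  # False: only decision_function is available
print(svm.decision_function(X[:3]))   # signed confidence scores, not probabilities

# Sigmoid (Platt) calibration fitted with internal 3-fold cross-validation
calibrated = CalibratedClassifierCV(LinearSVC(), method="sigmoid", cv=3).fit(X, y)
proba = calibrated.predict_proba(X[:5])
print(proba.shape)  # (5, 2)
```

Since `CoTrainingClassifier` accepts ordinary scikit-learn estimators in the cells of this notebook, such a wrapper could presumably be passed in place of `LinearSVC()`; that variant was not run here.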



In [7]:
print('SVM CoTraining Logistic')
svm_co_lg_clf = CoTrainingClassifier(LinearSVC(), clf2=LogisticRegression(), u=N_SAMPLES//10)
svm_co_lg_clf.fit(X1, X2, y)
y_pred = svm_co_lg_clf.predict(X_test[:, :N_FEATURES // 2], X_test[:, N_FEATURES // 2:])
print(classification_report(y_test, y_pred))

SVM CoTraining Logistic
             precision    recall  f1-score   support

          0       0.91      0.91      0.91      3088
          1       0.92      0.91      0.91      3162

avg / total       0.91      0.91      0.91      6250



In [8]:
print('Logistic CoTraining SVM')
lg_co_svm_clf = CoTrainingClassifier(LogisticRegression(), clf2=LinearSVC())
lg_co_svm_clf.fit(X1, X2, y)
y_pred = lg_co_svm_clf.predict(X_test[:, :N_FEATURES // 2], X_test[:, N_FEATURES // 2:])
print(classification_report(y_test, y_pred))

Logistic CoTraining SVM
             precision    recall  f1-score   support

          0       0.91      0.92      0.91      3088
          1       0.92      0.91      0.92      3162

avg / total       0.91      0.91      0.91      6250
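None of the cells above fixes a random seed: `make_classification` is called without `random_state`, and the introduction notes that the module imports `random`, so the figures in these reports will shift somewhat from run to run. A minimal sketch for pinning the sources of randomness, assuming the wrapper draws from Python's global `random` module (not verified here); the helper `seeded_draws` is hypothetical:

```python
import random

import numpy as np
from sklearn.datasets import make_classification

SEED = 0


def seeded_draws(seed):
    """Hypothetical helper: seed every global RNG, then generate data and one random draw."""
    random.seed(seed)     # global Python RNG (assumed to drive the wrapper's shuffling)
    np.random.seed(seed)  # global NumPy RNG, used by scikit-learn when random_state=None
    X, y = make_classification(n_samples=1000, n_features=50, random_state=seed)
    return X, y, random.random()


X_a, y_a, r_a = seeded_draws(SEED)
X_b, y_b, r_b = seeded_draws(SEED)
print(np.array_equal(X_a, X_b) and np.array_equal(y_a, y_b) and r_a == r_b)  # True
```

Passing `random_state` explicitly to `make_classification` (and to `LinearSVC`, which accepts one) is the more local option; seeding the global generators also covers code that does not expose a seed parameter.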

